torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
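Most of the hunks below apply one recurring migration: per-module hooks move to `apply(self, tensors, params, grads, loss, states, settings)`, per-parameter hyperparameters are read through `unpack_dicts(settings, ...)`, per-parameter buffers through `unpack_states(states, tensors, ...)`, step counters move from `self.current_step` into `self.global_state`, and nested modules are invoked via `apply_transform(...)`. The sketch below is assembled only from calls visible in these hunks to show the shape of the new API; the absolute import paths and the `PerParamScale` class are illustrative assumptions, not code from the package.

```python
import torch
from torchzero.core import Transform                  # assumed absolute form of the relative `...core` imports
from torchzero.utils import TensorList, unpack_states

class PerParamScale(Transform):
    """Hypothetical module (illustration only, not part of torchzero): scales the
    update by a per-parameter `alpha` and remembers the raw update in a `prev` buffer."""

    def __init__(self, alpha: float = 0.1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        # module-wide counters live in global_state, as in the migrated modules below
        # (not used further here; shown only for the pattern)
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        # one settings dict per parameter tensor -> read per-parameter hyperparameters
        alphas = [s['alpha'] for s in settings]

        # one state dict per parameter tensor -> buffers are created on first use
        prev = unpack_states(states, tensors, 'prev', cls=TensorList)

        tensors = TensorList(tensors)
        prev.copy_(tensors)                             # remember the raw update
        return tensors.mul_(tensors.full_like(alphas))  # scale each tensor by its own alpha
```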
@@ -3,8 +3,8 @@ from functools import partial
 
 import torch
 
-from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import (
     debias, debiased_step_size,
     ema_,
@@ -27,24 +27,25 @@ def adam_(
     pow: float = 2,
     debiased: bool = True,
     max_exp_avg_sq_: TensorList | None = None,
-
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
 ):
     """Returns new tensors or updates params in-place."""
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
                                    debiased=False,step=step,pow=pow)
 
-    if
+    if inner is not None:
+        assert params is not None
+        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
 
-
-    if
-
-        # update params in-place
-        params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
-        return None
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+    return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
 
-class Adam(
+class Adam(Transform):
     """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size. This implementation is slightly different from
     pytorch in that debiasing is applied after adding epsilon.
 
@@ -66,36 +67,29 @@ class Adam(Module):
         alpha: float = 1.,
         pow: float = 2,
         debiased: bool = True,
+        inner: Chainable | None = None
     ):
         defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-        super().__init__(defaults)
-
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None: self.set_child('inner', inner)
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        beta1,beta2,eps,alpha=
-        amsgrad,pow,debiased =
+        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq =
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq =
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
-        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if vars.is_last:
-            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
-            passed_params = TensorList(vars.params)
-            vars.stop = True
-            vars.skip_update = True
-
-        else:
-            passed_params = None
 
-
-        tensors=TensorList(
+        return adam_(
+            tensors=TensorList(tensors),
             exp_avg_=exp_avg,
             exp_avg_sq_=exp_avg_sq,
             alpha=alpha,
@@ -106,7 +100,10 @@ class Adam(Module):
             pow=pow,
             debiased=debiased,
             max_exp_avg_sq_=max_exp_avg_sq,
-            params_=passed_params,
-        )
 
-
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+
+        )
@@ -1,7 +1,7 @@
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 
 def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
@@ -28,8 +28,8 @@ class Lion(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
-        beta1, beta2 =
-        exp_avg =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
+        exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
         return lion_(TensorList(tensors),exp_avg,beta1,beta2)
 
@@ -164,7 +164,7 @@ class Orthogonalize(TensorwiseTransform):
         method (str, optional):
             Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
         target (str, optional):
-            what to set on
+            what to set on var.
     """
     def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
                  method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
@@ -172,9 +172,9 @@ class Orthogonalize(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
         orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
-            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(
+            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(settings)
 
         if not orthogonalize: return tensor
 
@@ -199,7 +199,7 @@ class DualNormCorrection(TensorwiseTransform):
     def __init__(self, target: Target='update'):
         super().__init__({}, uses_grad=True, target=target)
 
-    def
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
         assert grad is not None
         if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
             return _dual_norm_correction(tensor, grad, batch_first=False)
@@ -213,8 +213,8 @@ class MuonAdjustLR(Transform):
         defaults = dict(alpha=alpha)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
-    def
-        alphas =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        alphas = [s['alpha'] for s in settings]
         tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
         tensors = [i[0] for i in tensors_alphas]
         a = [i[1] for i in alphas]
@@ -30,16 +30,15 @@ class OrthoGrad(Transform):
     Args:
         eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
         renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
-        target (Target, optional): what to set on
+        target (Target, optional): what to set on var. Defaults to 'update'.
     """
     def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
         defaults = dict(eps=eps, renormalize=renormalize)
         super().__init__(defaults, uses_grad=False, target=target)
 
-    def
-
-
-        renormalize = settings['renormalize']
+    def apply(self, tensors, params, grads, loss, states, settings):
+        eps = settings[0]['eps']
+        renormalize = settings[0]['renormalize']
 
         params = as_tensorlist(params)
         target = as_tensorlist(tensors)
@@ -3,8 +3,8 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Target, Transform, Chainable,
-from ...utils import NumberList, TensorList
+from ...core import Module, Target, Transform, Chainable, Var, apply_transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import sqrt_centered_ema_sq_, sqrt_ema_sq_
 
 
@@ -23,7 +23,6 @@ def rmsprop_(
     inner: Module | None = None,
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
-    vars: Vars | None = None,
 ):
     """returns `tensors_`"""
     if exp_avg_ is not None:
@@ -36,7 +35,7 @@ def rmsprop_(
 
     if inner is not None:
         assert params is not None
-        tensors_ = TensorList(
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
 
     return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
 
@@ -66,21 +65,20 @@ class RMSprop(Transform):
     ):
         defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
         super().__init__(defaults=defaults, uses_grad=False)
-
+
         if inner is not None:
             self.set_child('inner', inner)
 
-    def
-        self.
-
-
-        centered,debiased,amsgrad,pow,init = itemgetter('centered','debiased','amsgrad','pow','init')(self.settings[params[0]])
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
+        centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])
 
-        exp_avg_sq =
-        exp_avg =
-        max_exp_avg_sq =
+        exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
+        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList) if centered else None
+        max_exp_avg_sq = unpack_states(states, tensors, 'max_exp_avg_sq', cls=TensorList) if amsgrad else None
 
-        if init == 'update' and
+        if init == 'update' and step == 1:
            exp_avg_sq.set_([t**2 for t in tensors])
            if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])
 
@@ -90,7 +88,7 @@ class RMSprop(Transform):
            smoothing=smoothing,
            eps=eps,
            debiased=debiased,
-           step=
+           step=step,
            exp_avg_=exp_avg,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
@@ -99,5 +97,4 @@ class RMSprop(Transform):
            inner=self.children.get("inner", None),
            params=params,
            grads=grads,
-           vars=vars,
        )
@@ -2,7 +2,7 @@
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
 
 
 def _bool_ones_like(x):
@@ -161,20 +161,22 @@ class Rprop(Transform):
         alpha: float = 1,
     ):
         defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
-        self.current_step = 0
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
-
-
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        nplus, nminus, lb, ub, alpha = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', 'alpha', cls=NumberList)
+        prev, allowed, magnitudes = unpack_states(
+            states, tensors,
            'prev','allowed','magnitudes',
-           params=params,
            init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
            cls = TensorList,
        )
 
-
+        tensors = rprop_(
            tensors_ = as_tensorlist(tensors),
            prev_ = prev,
            allowed_ = allowed,
@@ -184,12 +186,11 @@ class Rprop(Transform):
            lb = lb,
            ub = ub,
            alpha = alpha,
-           backtrack=
-           step=
+           backtrack=settings[0]['backtrack'],
+           step=step,
        )
 
-
-        return target
+        return tensors
 
 
 class ScaleLRBySignChange(Transform):
@@ -220,23 +221,25 @@ class ScaleLRBySignChange(Transform):
     ):
         defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
         super().__init__(defaults, uses_grad=use_grad, target=target)
-        self.current_step = 0
 
     @torch.no_grad
-    def
-
-
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        tensors = as_tensorlist(tensors)
+        use_grad = settings[0]['use_grad']
         if use_grad: cur = as_tensorlist(grads)
-        else: cur =
+        else: cur = tensors
 
-        nplus, nminus, lb, ub =
-        prev, lrs =
+        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
+        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)
 
-        if
-            lrs.set_(
+        if step == 0:
+            lrs.set_(tensors.full_like([s['alpha'] for s in settings]))
 
-
-        tensors_ =
+        tensors = scale_by_sign_change_(
+            tensors_ = tensors,
            cur = cur,
            prev_ = prev,
            lrs_ = lrs,
@@ -244,10 +247,9 @@ class ScaleLRBySignChange(Transform):
            nminus = nminus,
            lb = lb,
            ub = ub,
-           step =
+           step = step,
        )
-
-        return target
+        return tensors
 
 class BacktrackOnSignChange(Transform):
     """Negates or undoes update for parameters where where gradient or update sign changes.
@@ -268,28 +270,28 @@ class BacktrackOnSignChange(Transform):
     def __init__(self, use_grad = False, backtrack = True, target: Target = 'update'):
         defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
         super().__init__(defaults, uses_grad=use_grad)
-        self.current_step = 0
 
     @torch.no_grad
-    def
-
-
-
-
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        tensors = as_tensorlist(tensors)
+        use_grad = settings[0]['use_grad']
+        backtrack = settings[0]['backtrack']
 
         if use_grad: cur = as_tensorlist(grads)
-        else: cur =
+        else: cur = tensors
 
-
-        tensors_ =
+        tensors = backtrack_on_sign_change_(
+            tensors_ = tensors,
            cur = cur,
-           prev_ =
+            prev_ = unpack_states(states, tensors, 'prev', cls=TensorList),
            backtrack = backtrack,
-           step =
+            step = step,
        )
 
-
-        return target
+        return tensors
 
 class SignConsistencyMask(Transform):
     """0 if sign changed 1 otherwise"""
@@ -297,10 +299,10 @@ class SignConsistencyMask(Transform):
         super().__init__({}, uses_grad=False, target = target)
 
     @torch.no_grad
-    def
-        prev =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', cls=TensorList)
         mask = prev.mul_(tensors).gt_(0)
-        prev.
+        prev.copy_(tensors)
         return mask
 
 
@@ -317,16 +319,18 @@ class SignConsistencyLRs(Transform):
     ):
         defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
         super().__init__(defaults, uses_grad=False, target = target)
-        self.current_step = 0
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
        target = as_tensorlist(tensors)
-        nplus, nminus, lb, ub =
-        prev, lrs =
+        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
+        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)
 
-        if
-            lrs.set_(target.full_like(
+        if step == 0:
+            lrs.set_(target.full_like([s['alpha'] for s in settings]))
 
        target = sign_consistency_lrs_(
            tensors = target,
@@ -336,7 +340,6 @@ class SignConsistencyLRs(Transform):
            nminus = nminus,
            lb = lb,
            ub = ub,
-           step =
+            step = step,
        )
-        self.current_step += 1
        return target.clone()
@@ -4,7 +4,7 @@ from functools import partial
 import numpy as np
 import torch
 
-from ...core import Chainable, Transform,
+from ...core import Chainable, Transform, apply_transform
 from ...utils.linalg import matrix_power_eigh
 from ...utils import set_storage_
 
@@ -106,7 +106,6 @@ class Shampoo(Transform):
         self,
         decay: float | None = None,
         beta: float | None = None,
-        reg: float = 1e-6,
         update_freq: int = 10,
         exp_override: int | None = None,
         merge_small: bool = True,
@@ -115,25 +114,24 @@ class Shampoo(Transform):
         adagrad_eps: float = 1e-8,
         inner: Chainable | None = None,
     ):
-        defaults = dict(decay=decay, beta=beta,
+        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
         super().__init__(defaults, uses_grad=False)
 
         if inner is not None:
             self.set_child('inner', inner)
 
-    def
-
+    def apply(self, tensors, params, grads, loss, states, settings):
+        merged_tensors = [] # target with merged dims
 
         # update preconditioners
-        for i,(
-
-
-            beta, reg, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
-                'beta', 'reg', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(settings)
+        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
+            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
+                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(setting)
 
            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
+
+            merged_tensors.append(t)
 
            # initialize accumulators and preconditioners for each dim on 1st step
            if 'accumulators' not in state:
@@ -167,22 +165,18 @@ class Shampoo(Transform):
 
        # inner step
        if 'inner' in self.children:
-            tensors =
+            tensors = apply_transform(self.children['inner'], tensors, params=params, grads=grads)
 
        # have to merge small dims again
-
-        for i,(
-
-
-
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, settings['max_dim'])
-            merged_target.append(t)
+        merged_tensors = [] # target with merged dims
+        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
+            if setting['merge_small']:
+                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, setting['max_dim'])
+            merged_tensors.append(t)
 
        # precondition
-        for i,
-
-            settings = self.settings[p]
-            decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(settings)
+        for i,(t,state, setting) in enumerate(zip(merged_tensors, states, settings)):
+            decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(setting)
 
            if 'diagonal_accumulator' in state:
                tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)
@@ -2,7 +2,7 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Transform,
+from ...core import Chainable, Transform, apply_transform
 from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
 
 @torch.no_grad
@@ -152,9 +152,8 @@ class SOAP(Transform):
            epsilon for dividing first momentum by second. Defaults to 1e-8.
        decay (float | None, optional):
            Decays covariance matrix accumulators, this may be useful if `shampoo_beta` is None. Defaults to None.
-
-
-            results but True usually works better. Defaults to True.
+        alpha (float, optional):
+            learning rate. Defaults to 1.
        bias_correction (bool, optional):
            enables adam bias correction. Defaults to True.
    """
@@ -170,7 +169,6 @@ class SOAP(Transform):
        eps: float = 1e-8,
        decay: float | None = None,
        alpha: float = 1,
-        unprojected_exp_avg: bool = True,
        bias_correction: bool = True,
    ):
        defaults = dict(
@@ -183,21 +181,18 @@ class SOAP(Transform):
            precondition_1d=precondition_1d,
            eps=eps,
            decay=decay,
-            unprojected_exp_avg=unprojected_exp_avg,
            bias_correction=bias_correction,
            alpha=alpha,
        )
        super().__init__(defaults, uses_grad=False)
 
    @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
        updates = []
        # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-
-
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
+            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps,alpha = itemgetter(
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps','alpha')(setting)
 
            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -237,10 +232,7 @@ class SOAP(Transform):
            exp_avg: torch.Tensor = state["exp_avg"]
            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
 
-
-                exp_avg.lerp_(t, 1-beta1)
-            else:
-                exp_avg.lerp_(t_projected, 1-beta1)
+            exp_avg.lerp_(t, 1-beta1)
 
            if t_projected is None:
                exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
@@ -249,7 +241,7 @@ class SOAP(Transform):
 
            # project exponential moving averages if they are accumulated unprojected
            exp_avg_projected = exp_avg
-            if
+            if t_projected is not None:
                exp_avg_projected = project(exp_avg, state['Q'])
 
            exp_avg_sq_projected = exp_avg_sq
@@ -260,10 +252,11 @@ class SOAP(Transform):
            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
            # to the original space
            update = exp_avg_projected / denom
+
            if t_projected is not None:
                update = project_back(update, state["Q"])
 
-            if
+            if setting['bias_correction']:
                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -279,7 +272,7 @@ class SOAP(Transform):
            # Update is done after the gradient step to avoid using current gradients in the projection.
            if state['GG'] is not None:
                update_soap_covariances_(t, state['GG'], shampoo_beta)
-                if state['step'] %
+                if state['step'] % setting['precond_freq'] == 0:
                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
 
        return updates