torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
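The largest structural change in this release is the move of `torchzero/modules/optimizers/` to `torchzero/modules/adaptive/`. Below is a hypothetical compatibility sketch for downstream code that imports these submodules by path, assuming only the package path moved and the submodule contents are unchanged (the `soap` submodule name is taken from the listing above; top-level aliases may already hide the move):

```python
# Hypothetical shim for the optimizers -> adaptive rename shown in the listing.
# Assumes the submodule contents are unchanged; only the import path moved.
try:
    from torchzero.modules.adaptive import soap  # 0.3.13 layout
except ImportError:
    from torchzero.modules.optimizers import soap  # 0.3.11 layout
```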
```diff
--- torchzero/modules/experimental/adaptive_step_size.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from operator import itemgetter
-
-import torch
-
-from ..line_search import LineSearchBase
-
-
-class AdaptiveStepSize(LineSearchBase):
-    """Basic first order step size adaptation method. Re-evaluates the function after stepping, if value decreased sufficiently,
-    step size is increased. If value increased, step size is decreased.
-
-    .. note::
-        This works well in some cases, but it is often prone to collapsing.
-        For a more robust alternative use :code:`tz.m.AdaptiveBacktracking`.
-
-    Args:
-        nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
-        nminus (float, optional): multiplier to step size on unsuccessful steps. Defaults to 0.75.
-        c (float, optional): descent condition. Defaults to 1e-4.
-        init (float, optional): initial step size. Defaults to 1.
-        backtrack (bool, optional): whether to undo the step if value increased. Defaults to True.
-        adaptive (bool, optional):
-            If enabled, when multiple consecutive steps have been successful or unsuccessful,
-            the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
-
-
-    Examples:
-        Adagrad with trust region:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Adagrad(),
-                tz.m.TrustRegion()
-            )
-
-    """
-    def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
-        defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def search(self, update, var):
-
-        nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[var.params[0]])
-        step_size = self.global_state.setdefault('step_size', init)
-        previous_success = self.global_state.setdefault('previous_success', False)
-        nplus_mul = self.global_state.setdefault('nplus_mul', 1)
-        nminus_mul = self.global_state.setdefault('nminus_mul', 1)
-
-
-        f_0 = self.evaluate_step_size(0, var, backward=False)
-
-        # directional derivative (0 if c = 0 because it is not needed)
-        if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
-
-        # test step size
-        sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]
-
-        f_1 = self.evaluate_step_size(step_size, var, backward=False)
-
-        proposed = step_size
-
-        # very good step
-        if f_1 < sufficient_f:
-            self.global_state['step_size'] *= nplus * nplus_mul
-
-            # two very good steps in a row - increase nplus_mul
-            if adaptive:
-                if previous_success: self.global_state['nplus_mul'] *= nplus
-                else: self.global_state['nplus_mul'] = 1
-
-        # acceptable step step
-        #elif f_1 <= f_0: pass
-
-        # bad step
-        if f_1 >= f_0:
-            self.global_state['step_size'] *= nminus * nminus_mul
-
-            # two bad steps in a row - decrease nminus_mul
-            if adaptive:
-                if previous_success: self.global_state['nminus_mul'] *= nminus
-                else: self.global_state['nminus_mul'] = 1
-
-            if backtrack: proposed = 0
-            else: proposed *= nminus * nminus_mul
-
-        return proposed
```
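For reference, the rule the removed `AdaptiveStepSize` line search implemented is simple enough to state standalone: after a trial step, grow the step size when the loss decreases by at least `c * step_size * min(d, 0)` (with `d` the directional derivative), and shrink it, optionally undoing the step, when the loss increases. Below is a minimal plain-PyTorch sketch of that rule, not the torchzero API; the helper name `adapt_step_size` and the toy objective are illustrative only, and the consecutive-success multipliers are omitted:

```python
import torch

def adapt_step_size(f, x, g, step_size, nplus=1.5, nminus=0.75, c=1e-4):
    """One adaptation round of the basic rule; returns (new_x, new_step_size)."""
    f0 = f(x)
    d = -(g * g).sum()                     # directional derivative along -g
    trial = x - step_size * g
    f1 = f(trial)
    if f1 < f0 + c * step_size * min(d.item(), 0.0):  # sufficient decrease: accept and grow
        return trial, step_size * nplus
    if f1 >= f0:                                       # bad step: undo and shrink
        return x, step_size * nminus
    return trial, step_size                            # acceptable step: accept, keep step size

# usage on a toy quadratic
x = torch.tensor([3.0, -2.0])
f = lambda v: (v ** 2).sum()
step = 1.0
for _ in range(20):
    g = 2 * x                              # gradient of the toy objective
    x, step = adapt_step_size(f, x, g, step)
```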
```diff
--- torchzero/modules/experimental/adasoap.py
+++ /dev/null
@@ -1,177 +0,0 @@
-from operator import itemgetter
-
-import torch
-
-from ...core import Chainable, Transform
-from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-from ..optimizers.soap import (
-    get_orthogonal_matrix,
-    get_orthogonal_matrix_QR,
-    project,
-    project_back,
-)
-
-
-@torch.no_grad
-def update_adasoap_covariances_(
-    grad: torch.Tensor,
-    GGs_: list[torch.Tensor | None],
-    GG_sqs: list[torch.Tensor | None],
-    beta: float | None,
-    precond_beta: float | None,
-):
-    for i, (GG, GG_sq) in enumerate(zip(GGs_, GG_sqs)):
-        if GG is None: continue
-        assert GG_sq is not None
-
-        if precond_beta is None: GG_sq.addcmul_(GG, GG)
-        else: GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1-precond_beta)
-
-        axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
-        if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
-        else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-
-class AdaSOAP(Transform):
-    """SOAP with diagonally preconditioned GG^Ts.
-
-    .. warning::
-        Experimental.
-
-    precond_beta - beta for GG^T squares
-
-    Verdict: It works, but it is about the same performance as Adam, but maybe more tuning potential?
-    """
-    def __init__(
-        self,
-        beta1: float = 0.95,
-        beta2: float = 0.95,
-        shampoo_beta: float | None = 0.95,
-        precond_beta: float | None = 0.95,
-        precond_freq: int = 10,
-        merge_small: bool = True,
-        max_dim: int = 2_000,
-        precondition_1d: bool = True,
-        eps: float = 1e-8,
-        decay: float | None = None,
-        alpha: float = 1,
-        unprojected_exp_avg: bool = True,
-        bias_correction: bool = True,
-    ):
-        defaults = dict(
-            beta1=beta1,
-            beta2=beta2,
-            shampoo_beta=shampoo_beta,
-            precond_beta=precond_beta,
-            precond_freq=precond_freq,
-            merge_small=merge_small,
-            max_dim=max_dim,
-            precondition_1d=precondition_1d,
-            eps=eps,
-            decay=decay,
-            unprojected_exp_avg=unprojected_exp_avg,
-            bias_correction=bias_correction,
-            alpha=alpha,
-        )
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        updates = []
-        # update preconditioners
-        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(setting)
-            precond_beta = setting['precond_beta']
-
-            if merge_small:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-            # initialize state on 1st step
-            if 'GG' not in state:
-                state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.zeros_like(t)
-
-                if not precondition_1d and t.ndim <= 1:
-                    state['GG'] = []
-                    state['GG_sq'] = []
-
-                else:
-                    state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-                    state['GG_sq'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-
-                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                if len([i is not None for i in state['GG']]) == 0:
-                    state['GG'] = None
-                    state['GG_sq'] = None
-
-                if state['GG'] is not None:
-                    assert state['GG_sq'] is not None
-                    update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
-                    GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                    state['Q'] = get_orthogonal_matrix(GG_precond)
-
-                state['step'] = 0
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                # that can mess with other modules scaling
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            t_projected = None
-            if state['GG'] is not None:
-                t_projected = project(t, state['Q'])
-
-            # exponential moving averages
-            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-            if unprojected_exp_avg or t_projected is None:
-                exp_avg.lerp_(t, 1-beta1)
-            else:
-                exp_avg.lerp_(t_projected, 1-beta1)
-
-            if t_projected is None:
-                exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-            else:
-                exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
-
-            # project exponential moving averages if they are accumulated unprojected
-            exp_avg_projected = exp_avg
-            if unprojected_exp_avg and t_projected is not None:
-                exp_avg_projected = project(exp_avg, state['Q'])
-
-            exp_avg_sq_projected = exp_avg_sq
-
-            denom = exp_avg_sq_projected.sqrt().add_(eps)
-            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            update = exp_avg_projected / denom
-            if t_projected is not None:
-                update = project_back(update, state["Q"])
-
-            if setting['bias_correction']:
-                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-            elif alpha is not None:
-                update *= alpha
-
-            if merge_small:
-                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-            updates.append(update)
-            state["step"] += 1
-
-            # Update is done after the gradient step to avoid using current gradients in the projection.
-            if state['GG'] is not None:
-                update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
-                GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                if state['step'] % setting['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])
-
-        return updates
```
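The piece that distinguished the removed `AdaSOAP` from the regular SOAP module is the diagonal preconditioning of the Shampoo accumulators: alongside the EMA of `G Gᵀ` per dimension it kept an EMA of that matrix's elementwise square, and computed the eigenbasis from their elementwise ratio. Below is a rough standalone sketch of that accumulator update (plain PyTorch, a single 2D gradient and only the first dimension shown; illustrative only, not the torchzero API):

```python
import torch

beta, precond_beta, eps = 0.95, 0.95, 1e-8
GG = torch.zeros(8, 8)      # EMA of g @ g.T over the first dimension
GG_sq = torch.zeros(8, 8)   # EMA of GG's elementwise square

for _ in range(10):
    g = torch.randn(8, 4)                                              # stand-in gradient
    GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1 - precond_beta)  # square EMA uses the previous GG
    GG.lerp_(g @ g.T, 1 - beta)

GG_precond = GG / (GG_sq + eps)                 # diagonally preconditioned accumulator
Q = torch.linalg.eigh(GG_precond).eigenvectors  # eigenbasis used for the SOAP-style projection
```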
```diff
--- torchzero/modules/experimental/cosine.py
+++ /dev/null
@@ -1,214 +0,0 @@
-"""A bunch of useless modules that I hate and that didn't work"""
-import torch
-
-from ...core import Chainable, Transform, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
-
-
-class CosineStepSize(Transform):
-    """Adaptive step size based on cosine similarity
-
-    VERDICT: Useless. This is too unstable.
-
-    Args:
-        scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
-        init (float, optional): initial step size. Defaults to 1.
-        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
-        target_cossim (float, optional): cosine similarity needs to be above this to increase step size. Defaults to 1e-8.
-        inner (Chainable | None, optional):
-            inner modules applied after calculating cosine similarity and before step size correction. Defaults to None.
-    """
-    def __init__(self, scale:float = 0.95, init:float=1, eps:float=1e-12, inner:Chainable | None = None):
-        defaults = dict(scale=scale, init=init, eps=eps)
-        super().__init__(defaults, uses_grad=False)
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        scale, init = unpack_dicts(settings, 'scale', 'init', cls=NumberList)
-        unpack_states(states, tensors, 'alpha', init=init, cls=NumberList) # initializes alpha to init
-        eps = settings[0]['eps']
-
-        tensors = as_tensorlist(tensors)
-        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
-
-        tensors_norm = tensors.global_vector_norm()
-        cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
-
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-        new_alpha = []
-        for s, sc in zip(states, scale):
-            s['alpha'] *= 1 + cos_sim * sc
-            new_alpha.append(s['alpha'])
-
-        tensors.mul_(new_alpha)
-        prev.copy_(tensors)
-
-        return tensors
-
-
-
-class CosineDebounce(Transform):
-    """Debouncing when cosine similarity is less than 0.
-
-    VERDICT: Useless. This doesn't help at all.
-
-    Args:
-        scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
-        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
-        inner (Chainable | None, optional):
-            inner modules applied after calculating cosine similarity and before debouncing correction. Defaults to None.
-    """
-    def __init__(self, scale:float = 0.95, eps:float=1e-12, damping:float=0.95, inner:Chainable | None = None):
-        defaults = dict(scale=scale, eps=eps, damping=damping)
-        super().__init__(defaults, uses_grad=False)
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
-        eps = settings[0]['eps']
-
-        tensors = as_tensorlist(tensors)
-        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList).mul_(damping)
-
-        tensors_norm = tensors.global_vector_norm()
-        cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
-
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-        if cos_sim < -eps:
-            undo = prev.neg().mul_(-cos_sim * scale)
-            comb = prev.graft(tensors).add_(tensors).graft_(prev).mul_(-cos_sim*scale)
-            tensors = undo.add_(comb)
-
-        prev.copy_(tensors)
-        return tensors
-
-
-
-class CosineMomentum(Transform):
-    """Beta depends on cosine similarity. At cossim=1, beta is 0. At cossim=-1, beta is 2^power. This basically removes oscillations.
-
-    VERDICT: Useless. Worse than all other momentums.
-
-    Args:
-        scale (float, optional): cosine similarity multiplier. Defaults to 1.
-        nesterov (float, optional): whether to use nesterov momentum. Defaults to False.
-        power (float, optional): power for beta. Defaults to 1.
-        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
-        inner (Chainable | None, optional):
-            inner modules applied after calculating cosine similarity and before updating exponential moving average. Defaults to None.
-    """
-    def __init__(self, scale:float = 1, nesterov: bool = False, power: float = 1, eps:float=1e-12, inner:Chainable | None = None):
-        defaults = dict(scale=scale, eps=eps, nesterov=nesterov, power=power)
-        super().__init__(defaults, uses_grad=False)
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        scale, power = unpack_dicts(settings, 'scale', 'power', cls=NumberList)
-        eps = settings[0]['eps']
-        nesterov = settings[0]['nesterov']
-        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
-
-        tensors = as_tensorlist(tensors)
-
-        tensors_norm = tensors.global_vector_norm()
-        cos_sim = (tensors.dot(exp_avg) / (tensors_norm * exp_avg.global_vector_norm()).clip(min=eps)).item()
-
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-        beta = (1 - (cos_sim*scale)) ** power
-        if nesterov:
-            exp_avg.add_(tensors.mul(beta))
-            return tensors.add_(exp_avg)
-        else:
-            exp_avg.add_(tensors.mul_(beta))
-            return exp_avg.clone()
-
-
-class AdaptiveDifference(Transform):
-    """VERDICT: Useless. Doesn't help (sort of to be expected)."""
-    def __init__(self, inner:Chainable | None = None):
-        defaults = dict()
-        super().__init__(defaults, uses_grad=False)
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = as_tensorlist(tensors)
-        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
-
-        diff = tensors - prev.graft_(tensors)
-        prev.copy_(tensors)
-
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-        tensors.add_(diff.graft_(tensors))
-
-        return tensors
-
-class AdaptiveDifferenceEMA(Transform):
-    """VERDICT: better than non-EMA but still useless."""
-    def __init__(self, beta=0.99, inner:Chainable | None = None):
-        defaults = dict(beta=beta)
-        super().__init__(defaults, uses_grad=False)
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = as_tensorlist(tensors)
-        beta = unpack_dicts(settings, 'beta', cls=NumberList)
-        prev, diff_exp_avg = unpack_states(states, tensors, 'prev', 'diff_exp_avg', init=[tensors,torch.zeros_like], cls=TensorList)
-
-        diff = (tensors - prev.graft_(tensors)).graft_(tensors)
-        diff_exp_avg.lerp_(diff, 1-beta)
-        prev.copy_(tensors)
-
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-        tensors.add_(diff_exp_avg.graft(tensors))
-
-        return tensors
-
-
-class ScaledAdaptiveDifference(Transform):
-    """VERDICT: Useless and doesn't help."""
-    def __init__(self, scale=0.95, damping:float=0.99, inner:Chainable | None = None):
-        defaults = dict(scale=scale, damping=damping)
-        super().__init__(defaults, uses_grad=False)
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = as_tensorlist(tensors)
-        scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
-        prev_tensors, prev_update = unpack_states(states, tensors, 'prev', 'prev_update', init=[tensors,tensors], cls=TensorList)
-
-        cos_sim = (tensors.dot(prev_update) / (tensors.global_vector_norm() * prev_update.global_vector_norm()).clip(min=1e-10)).item()
-
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
-
-        if cos_sim > 0:
-            tensors.add_(prev_tensors*(cos_sim*scale))
-
-        else:
-            undo = prev_tensors.neg().mul_(-cos_sim*scale)
-            comb = prev_tensors.graft(tensors).add_(tensors).graft_(prev_tensors).mul_(-cos_sim*scale)
-            tensors = undo.add_(comb).graft_((tensors-prev_tensors).mul_(damping))
-
-        diff = tensors - prev_tensors.graft_(tensors)
-        prev_tensors.copy_(tensors)
-        diff.graft_(tensors)
-        tensors.add_(diff)
-        prev_update.copy_(tensors)
-
-        return tensors
```
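The idea shared by the removed cosine modules is to compare the current update with the previous one and react to their cosine similarity; `CosineStepSize`, for example, multiplies a running step-size factor by `1 + scale * cos_sim` each step. Below is a self-contained sketch of that rule on flat vectors, assuming plain gradient-descent updates (illustrative helper, not the torchzero API):

```python
import torch

def cosine_step_scale(update, prev_update, alpha, scale=0.95, eps=1e-12):
    """Grow alpha when consecutive updates agree in direction, shrink it when they flip."""
    cos_sim = torch.dot(update, prev_update) / (update.norm() * prev_update.norm()).clamp(min=eps)
    return alpha * (1 + scale * cos_sim.item())

alpha, prev = 1.0, torch.randn(10)
for _ in range(5):
    g = torch.randn(10)                 # stand-in for the current update
    alpha = cosine_step_scale(g, prev, alpha)
    prev = g
```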
```diff
--- torchzero/modules/experimental/cubic_adam.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import torch
-
-from ...core import Transform
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-
-
-def signed_cbrt(x: TensorList) -> TensorList:
-    return x.sign() * x.abs().pow(1/3)
-
-def cubic_adam_(
-    tensors: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_sq_: TensorList,
-    exp_avg_cu_: TensorList,
-    alpha: float | NumberList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    beta3: float | NumberList,
-    eps: float | NumberList,
-    debiased: bool,
-    step: int,
-):
-    exp_avg_.lerp_(tensors, 1-beta1)
-    exp_avg_sq_.lerp_(tensors**2, 1-beta2)
-    exp_avg_cu_.lerp_(tensors**3, 1-beta3)
-
-    if debiased:
-        m1 = exp_avg_ / (1 - beta1 ** step)
-        m2 = exp_avg_sq_ / (1 - beta2 ** step)
-        m3 = exp_avg_cu_ / (1 - beta3 ** step)
-    else:
-        m1, m2, m3 = exp_avg_, exp_avg_sq_, exp_avg_cu_
-
-    # adam minimizes ax^2 + bx
-    # we are going to minimize ax^3 + bx^2 + cx
-    A = signed_cbrt(m3)
-    B = m2.sqrt()
-    C = m1
-    discriminant = B.pow(2) - 4 * A * C
-
-    denom = 2 * A
-    root = discriminant.clamp(min=0).sqrt_()
-
-    x0 = (-B + root) / (denom + eps)
-    x1 = (-B - root) / (denom + eps)
-
-    f0 = (A/3)*x0**3 + (B/2)*x0**2 + C*x0
-    f1 = (A/3)*x1**3 + (B/2)*x1**2 + C*x1
-
-    x_star = x0.where(f0 < f1, x1)
-
-    adam = -C / (B + eps)
-    x_star = adam.where(discriminant < 0, x_star)
-
-    return x_star.mul_(-alpha)
-
-class CubicAdam(Transform):
-    """Adam which has 3rd momentum and minimizes a cubic polynomial.
-
-    VERDICT: can outperform Adam very slightly. Usually very similar performance.
-
-    .. warning::
-        Experimental.
-
-    """
-    def __init__(
-        self,
-        beta1: float = 0.9,
-        beta2: float = 0.99,
-        beta3: float = 0.99,
-        eps: float = 1e-8,
-        debiased:bool=True,
-        alpha: float = 1.,
-    ):
-        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        beta1,beta2,beta3,eps,alpha=unpack_dicts(settings, 'beta1','beta2','beta3','eps','alpha', cls=NumberList)
-        exp_avg, exp_avg_sq, exp_avg_cu = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
-
-        return cubic_adam_(
-            tensors=TensorList(tensors),
-            exp_avg_=exp_avg,
-            exp_avg_sq_=exp_avg_sq,
-            exp_avg_cu_=exp_avg_cu,
-            alpha=alpha,
-            beta1=beta1,
-            beta2=beta2,
-            beta3=beta3,
-            eps=eps,
-            debiased=settings[0]['debiased'],
-            step=step,
-        )
```
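The removed `CubicAdam` extends Adam's implicit quadratic model with a third moment: with `A = cbrt(m3)`, `B = sqrt(m2)`, `C = m1`, it minimizes `(A/3)x^3 + (B/2)x^2 + Cx` by solving `Ax^2 + Bx + C = 0` and falls back to Adam's `-C/(B + eps)` when the discriminant is negative. Below is a scalar walk-through of that per-element computation (illustrative helper, not the torchzero API):

```python
import math

def cubic_step(m1, m2, m3, eps=1e-8):
    """Per-element minimizer of (A/3)x^3 + (B/2)x^2 + Cx, as used by the removed CubicAdam."""
    A = math.copysign(abs(m3) ** (1 / 3), m3)   # signed cube root of the third moment
    B = math.sqrt(m2)
    C = m1
    disc = B * B - 4 * A * C
    if disc < 0:
        return -C / (B + eps)                   # Adam fallback when the cubic has no stationary points
    root = math.sqrt(disc)
    x0 = (-B + root) / (2 * A + eps)
    x1 = (-B - root) / (2 * A + eps)
    cubic = lambda x: (A / 3) * x ** 3 + (B / 2) * x ** 2 + C * x
    return x0 if cubic(x0) < cubic(x1) else x1  # keep the stationary point with the lower model value

print(cubic_step(m1=0.01, m2=0.09, m3=0.001))
```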