torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from ...core import Chainable, Module, Target, Transform, apply_transform
|
|
8
|
+
from ...utils import NumberList, TensorList, as_tensorlist
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def esgd_(
    tensors_: TensorList,
    D: TensorList | None,
    D_sq_acc_: TensorList,
    damping: float | NumberList,
    update_freq: int,
    step: int,
    i: int,
):
    """Divide ``tensors_`` in-place by the ESGD equilibration denominator.

    On update steps (``step % update_freq == 0``) the estimate ``D`` is squared and
    added in-place to the running sum ``D_sq_acc_``, and the counter ``i`` is incremented.
    On all other steps ``D`` must be ``None`` and the running sum is left unchanged.

    The denominator is ``sqrt(D_sq_acc_ / max(i, 1)) + damping``. ``max(i, 1)`` avoids
    dividing by zero before the first estimate has been accumulated.

    Returns:
        tuple: (``tensors_`` after the in-place division, the updated counter ``i``).
    """
    is_update_step = step % update_freq == 0
    if not is_update_step:
        assert D is None
    else:
        assert D is not None
        D_sq_acc_.addcmul_(D, D)
        i += 1

    mean_sq = D_sq_acc_ / max(i, 1)
    denom = mean_sq.sqrt_().add_(damping)
    return tensors_.div_(denom), i
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ESGD(Module):
    """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

    This is similar to Adagrad, but it accumulates squared randomized hessian diagonal estimates
    (hessian-vector products with random gaussian vectors) instead of squared gradients.
    The update is divided by ``sqrt(mean of accumulated squared estimates) + damping``.

    .. note::
        In most cases ESGD should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply ESGD preconditioning to another module's output.

    .. note::
        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".

    .. note::
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating HVPs.
        The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        damping (float, optional): added to denominator for stability. Defaults to 1e-4.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product.
            This value can be increased to reduce computational cost. Defaults to 20.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate and average each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional):
            Inner module. If this is specified, operations are performed in the following order.
            1. compute hessian diagonal estimate (on update steps only).
            2. pass inputs to :code:`inner`.
            3. preconditioning is applied to the outputs of :code:`inner`.

    Raises:
        RuntimeError: if no closure was passed to the optimizer step.

    Examples:
        Using ESGD:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.ESGD(),
                tz.m.LR(0.1)
            )

        ESGD preconditioning can be applied to the output of any other module by passing that module to
        the :code:`inner` argument. Here is an example of applying ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.ESGD(inner=tz.m.NAG(0.9)),
                tz.m.LR(0.1)
            )

    """
    def __init__(
        self,
        damping: float = 1e-4,
        update_freq: int = 20,
        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
        fd_h: float = 1e-3,
        n_samples: int = 1,
        seed: int | None = None,
        inner: Chainable | None = None
    ):
        defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
        super().__init__(defaults)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def step(self, var):
        params = var.params
        settings = self.settings[params[0]]
        hvp_method = settings['hvp_method']
        fd_h = settings['fd_h']
        update_freq = settings['update_freq']
        n_samples = settings['n_samples']

        # the generator is stored in global_state so that the seeded random sequence
        # continues across steps instead of being reset every step
        seed = settings['seed']
        generator = None
        if seed is not None:
            if 'generator' not in self.global_state:
                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
            generator = self.global_state['generator']

        damping = self.get_settings(params, 'damping', cls=NumberList)
        D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
        i = self.global_state.get('i', 0) # number of hessian diagonal estimates accumulated so far

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        # explicit check instead of `assert`, which is stripped when running with -O
        if var.closure is None:
            raise RuntimeError("ESGD requires a closure to be passed to the optimizer step.")

        D = None
        if step % update_freq == 0:

            # average of n_samples hessian-vector products with random gaussian vectors
            rgrad = None
            for j in range(n_samples):
                u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]

                # rgrad returned by the previous call is passed back in; retain_grad is True for
                # all but the last sample (presumably to keep the gradient graph for later samples)
                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                     h=fd_h, normalize=True, retain_grad=j < n_samples-1)

                if D is None: D = Hvp
                else: torch._foreach_add_(D, Hvp)

            assert D is not None
            if n_samples > 1: torch._foreach_div_(D, n_samples)

            D = TensorList(D)

        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)

        # D is already a TensorList (or None on non-update steps), no need to re-wrap it
        var.update, self.global_state['i'] = esgd_(
            tensors_=TensorList(update),
            D=D,
            D_sq_acc_=D_sq_acc,
            damping=damping,
            update_freq=update_freq,
            step=step,
            i=i,
        )
        return var
|
|
@@ -1,55 +1,53 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
|
-
import math
|
|
3
1
|
from collections import deque
|
|
4
2
|
from typing import Literal, Any
|
|
5
|
-
import
|
|
3
|
+
import warnings
|
|
6
4
|
|
|
7
5
|
import torch
|
|
8
6
|
from ...core import Chainable, TensorwiseTransform
|
|
9
|
-
from ...utils.linalg.matrix_funcs import matrix_power_eigh
|
|
10
|
-
from ...utils.linalg.svd import randomized_svd
|
|
11
|
-
from ...utils.linalg.qr import qr_householder
|
|
12
7
|
|
|
13
|
-
def
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
8
|
+
def lm_adagrad_update(history: deque[torch.Tensor], damping, rdamping):
|
|
9
|
+
M = torch.stack(tuple(history), dim=1)# / len(history)
|
|
10
|
+
MTM = M.T @ M
|
|
11
|
+
if damping != 0:
|
|
12
|
+
MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))
|
|
17
13
|
|
|
18
14
|
try:
|
|
19
|
-
|
|
20
|
-
U = U.to(device); S = S.to(device)
|
|
15
|
+
L, Q = torch.linalg.eigh(MTM) # pylint:disable=not-callable
|
|
21
16
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
S.pow_(2)
|
|
27
|
-
Iu **= 2
|
|
28
|
-
S.add_(Iu)
|
|
29
|
-
if true_damping: S.sqrt_()
|
|
17
|
+
tol = torch.finfo(M.dtype).eps * L.amax() # remove small eigenvalues
|
|
18
|
+
indices = L > tol
|
|
19
|
+
L = L[indices]
|
|
20
|
+
Q = Q[:, indices]
|
|
30
21
|
|
|
31
|
-
|
|
22
|
+
U = (M @ Q) * L.rsqrt()
|
|
23
|
+
|
|
24
|
+
if rdamping != 0:
|
|
25
|
+
rdamping *= torch.linalg.vector_norm(L) # pylint:disable=not-callable
|
|
26
|
+
L.add_(rdamping)
|
|
27
|
+
|
|
28
|
+
return U, L
|
|
32
29
|
|
|
33
30
|
except torch.linalg.LinAlgError:
|
|
34
31
|
return None, None
|
|
35
32
|
|
|
36
|
-
def
|
|
37
|
-
|
|
38
|
-
return U @
|
|
39
|
-
|
|
33
|
+
def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
    """Precondition a flattened update ``g`` with the factors ``U`` and ``L``.

    Computes ``U @ diag(L ** -0.5) @ U.T @ g``. In this file ``U`` and ``L`` are the values
    stored in state by ``lm_adagrad_update``, and ``g`` is passed as ``tensor.view(-1)``.
    """
    coords = U.T @ g
    scaled_basis = U * L.rsqrt()
    return scaled_basis @ coords
|
|
40
36
|
|
|
41
37
|
def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
|
|
42
38
|
if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
|
|
43
39
|
else:
|
|
44
|
-
if state_[key].shape != value.shape: state_[key] = value
|
|
40
|
+
if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
|
|
45
41
|
else: state_[key].lerp_(value, 1-beta)
|
|
46
42
|
|
|
47
|
-
class
|
|
43
|
+
class LMAdagrad(TensorwiseTransform):
|
|
48
44
|
"""
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
45
|
+
Limited-memory full matrix Adagrad.
|
|
46
|
+
|
|
47
|
+
The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
|
|
48
|
+
But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't neeed V.
|
|
49
|
+
|
|
50
|
+
This is equivalent to full-matrix Adagrad on recent gradients.
|
|
53
51
|
|
|
54
52
|
Args:
|
|
55
53
|
history_size (int, optional): number of past gradients to store. Defaults to 10.
|
|
@@ -60,55 +58,84 @@ class SpectralPreconditioner(TensorwiseTransform):
|
|
|
60
58
|
order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
|
|
61
59
|
true_damping (bool, optional):
|
|
62
60
|
If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
|
|
61
|
+
eigh (bool, optional): uses a more efficient way to calculate U and S. Defaults to True.
|
|
63
62
|
U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
|
|
64
|
-
S_beta (float | None, optional): momentum for
|
|
63
|
+
S_beta (float | None, optional): momentum for S (too unstable, don't use). Defaults to None.
|
|
65
64
|
interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
|
|
66
|
-
concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to
|
|
67
|
-
normalize (bool, optional): whether to normalize gradients, this doesn't work well so don't use it. Defaults to False.
|
|
68
|
-
centralize (bool, optional): whether to centralize gradients, this doesn't work well so don't use it. Defaults to False.
|
|
65
|
+
concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
|
|
69
66
|
inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
|
|
67
|
+
|
|
68
|
+
Examples:
|
|
69
|
+
Limited-memory Adagrad
|
|
70
|
+
|
|
71
|
+
.. code-block:: python
|
|
72
|
+
|
|
73
|
+
optimizer = tz.Modular(
|
|
74
|
+
model.parameters(),
|
|
75
|
+
tz.m.LMAdagrad(),
|
|
76
|
+
tz.m.LR(0.1)
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)
|
|
80
|
+
|
|
81
|
+
.. code-block:: python
|
|
82
|
+
|
|
83
|
+
optimizer = tz.Modular(
|
|
84
|
+
model.parameters(),
|
|
85
|
+
tz.m.LMAdagrad(inner=tz.m.EMA()),
|
|
86
|
+
tz.m.Debias(0.9, 0.999),
|
|
87
|
+
tz.m.LR(0.01)
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
|
|
91
|
+
|
|
92
|
+
.. code-block:: python
|
|
93
|
+
|
|
94
|
+
optimizer = tz.Modular(
|
|
95
|
+
model.parameters(),
|
|
96
|
+
tz.m.LMAdagrad(inner=tz.m.EMA()),
|
|
97
|
+
tz.m.Debias(0.9, 0.999),
|
|
98
|
+
tz.m.ClipNormByEMA(max_ema_growth=1.2),
|
|
99
|
+
tz.m.LR(0.01)
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
Reference:
|
|
103
|
+
Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – С. 102-110.
|
|
70
104
|
"""
|
|
71
105
|
|
|
72
106
|
def __init__(
|
|
73
107
|
self,
|
|
74
|
-
history_size: int =
|
|
108
|
+
history_size: int = 100,
|
|
75
109
|
update_freq: int = 1,
|
|
76
110
|
damping: float = 1e-4,
|
|
77
111
|
rdamping: float = 0,
|
|
78
112
|
order: int = 1,
|
|
79
113
|
true_damping: bool = True,
|
|
80
114
|
U_beta: float | None = None,
|
|
81
|
-
|
|
115
|
+
L_beta: float | None = None,
|
|
82
116
|
interval: int = 1,
|
|
83
|
-
concat_params: bool =
|
|
84
|
-
normalize: bool=False,
|
|
85
|
-
centralize:bool = False,
|
|
117
|
+
concat_params: bool = True,
|
|
86
118
|
inner: Chainable | None = None,
|
|
87
119
|
):
|
|
88
120
|
# history is still updated each step so Precondition's update_freq has different meaning
|
|
89
|
-
defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta,
|
|
121
|
+
defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
|
|
90
122
|
super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)
|
|
91
123
|
|
|
92
124
|
@torch.no_grad
|
|
93
|
-
def update_tensor(self, tensor, param, grad, loss, state,
|
|
94
|
-
order =
|
|
95
|
-
history_size =
|
|
96
|
-
update_freq =
|
|
97
|
-
damping =
|
|
98
|
-
rdamping =
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
S_beta = settings['S_beta']
|
|
102
|
-
normalize = settings['normalize']
|
|
103
|
-
centralize = settings['centralize']
|
|
125
|
+
def update_tensor(self, tensor, param, grad, loss, state, setting):
|
|
126
|
+
order = setting['order']
|
|
127
|
+
history_size = setting['history_size']
|
|
128
|
+
update_freq = setting['update_freq']
|
|
129
|
+
damping = setting['damping']
|
|
130
|
+
rdamping = setting['rdamping']
|
|
131
|
+
U_beta = setting['U_beta']
|
|
132
|
+
L_beta = setting['L_beta']
|
|
104
133
|
|
|
105
134
|
if 'history' not in state: state['history'] = deque(maxlen=history_size)
|
|
106
135
|
history = state['history']
|
|
107
136
|
|
|
108
137
|
if order == 1:
|
|
109
138
|
t = tensor.clone().view(-1)
|
|
110
|
-
if centralize: t -= t.mean()
|
|
111
|
-
if normalize: t /= torch.linalg.vector_norm(t).clip(min=1e-8) # pylint:disable=not-callable
|
|
112
139
|
history.append(t)
|
|
113
140
|
else:
|
|
114
141
|
|
|
@@ -122,42 +149,35 @@ class SpectralPreconditioner(TensorwiseTransform):
|
|
|
122
149
|
state[f'prev_g_{i}'] = cur_g
|
|
123
150
|
break
|
|
124
151
|
|
|
125
|
-
|
|
126
|
-
|
|
152
|
+
s = cur_p - state[f'prev_p_{i}']
|
|
153
|
+
y = cur_g - state[f'prev_g_{i}']
|
|
127
154
|
state[f'prev_p_{i}'] = cur_p
|
|
128
155
|
state[f'prev_g_{i}'] = cur_g
|
|
129
|
-
cur_p =
|
|
130
|
-
cur_g =
|
|
156
|
+
cur_p = s
|
|
157
|
+
cur_g = y
|
|
131
158
|
|
|
132
159
|
if i == order - 1:
|
|
133
|
-
|
|
134
|
-
if normalize: cur_g = cur_g / torch.linalg.vector_norm(cur_g).clip(min=1e-8) # pylint:disable=not-callable
|
|
135
|
-
else: cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
|
|
160
|
+
cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
|
|
136
161
|
history.append(cur_g.view(-1))
|
|
137
162
|
|
|
138
163
|
step = state.get('step', 0)
|
|
139
164
|
if step % update_freq == 0 and len(history) != 0:
|
|
140
|
-
U,
|
|
165
|
+
U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
|
|
141
166
|
maybe_lerp_(state, U_beta, 'U', U)
|
|
142
|
-
maybe_lerp_(state,
|
|
167
|
+
maybe_lerp_(state, L_beta, 'L', L)
|
|
143
168
|
|
|
144
169
|
if len(history) != 0:
|
|
145
170
|
state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)
|
|
146
171
|
|
|
147
172
|
@torch.no_grad
|
|
148
|
-
def apply_tensor(self, tensor, param, grad, loss, state,
|
|
149
|
-
history_size = settings['history_size']
|
|
150
|
-
|
|
173
|
+
def apply_tensor(self, tensor, param, grad, loss, state, setting):
|
|
151
174
|
U = state.get('U', None)
|
|
152
175
|
if U is None:
|
|
153
176
|
# make a conservative step to avoid issues due to different GD scaling
|
|
154
177
|
return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
|
|
155
178
|
|
|
156
|
-
|
|
157
|
-
update =
|
|
179
|
+
L = state['L']
|
|
180
|
+
update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)
|
|
158
181
|
|
|
159
|
-
n = len(state['history'])
|
|
160
|
-
mh = min(history_size, 10)
|
|
161
|
-
if n <= mh: update.mul_(n/mh)
|
|
162
182
|
return update
|
|
163
183
|
|
|
@@ -28,7 +28,7 @@ class Lion(Transform):
|
|
|
28
28
|
super().__init__(defaults, uses_grad=False)
|
|
29
29
|
|
|
30
30
|
@torch.no_grad
|
|
31
|
-
def
|
|
31
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
32
32
|
beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
|
|
33
33
|
exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
|
|
34
34
|
return lion_(TensorList(tensors),exp_avg,beta1,beta2)
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from operator import itemgetter
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from ...core import Module, Target, Transform, apply_transform, Chainable
|
|
7
|
+
from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
|
|
8
|
+
from ..functional import (
|
|
9
|
+
debias, debiased_step_size,
|
|
10
|
+
ema_,
|
|
11
|
+
sqrt_ema_sq_,
|
|
12
|
+
)
|
|
13
|
+
from ..step_size.lr import lazy_lr
|
|
14
|
+
from ..momentum.experimental import sqrt_nag_ema_sq_
|
|
15
|
+
from ..momentum.momentum import nag_
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def mars_correction_(
    tensors_: TensorList,
    prev_: TensorList,
    beta: float | NumberList,
    scaling: float | NumberList,
    max_norm: float | NumberList | None,
):
    """Apply the MARS variance-reduction correction to ``tensors_`` in-place.

    Computes ``tensors_ + scaling * beta / (1 - beta) * (tensors_ - prev_)``.
    Before the correction is added, the uncorrected ``tensors_`` are copied into ``prev_``
    for the next call. If ``max_norm`` is not None, the result is then clipped with
    ``clip_norm_(..., tensorwise=False)``.

    Returns:
        The corrected ``tensors_`` (modified in-place).
    """
    diff = tensors_ - prev_
    correction = diff.mul_(scaling * beta / (1-beta))

    # remember the raw (uncorrected) values for the next step
    prev_.copy_(tensors_)

    corrected = tensors_.add_(correction)
    if max_norm is not None:
        corrected.clip_norm_(max=max_norm, tensorwise=False)

    return corrected
|
|
33
|
+
|
|
34
|
+
class MARSCorrection(Transform):
    """MARS variance reduction correction.

    Each step the incoming update ``g`` is replaced with
    ``g + scaling * beta / (1 - beta) * (g - g_prev)``, where ``g_prev`` is the uncorrected
    update seen on the previous step. If :code:`max_norm` is not None, the corrected update is
    then norm-clipped.

    Place any other momentum-based optimizer after this module, and make sure the
    :code:`beta` parameter matches the momentum used by that optimizer.

    Args:
        beta (float, optional): use the same beta as you use in the momentum module. Defaults to 0.9.
        scaling (float, optional): controls the scale of gradient correction in variance reduction. Defaults to 0.025.
        max_norm (float, optional): clips norm of corrected gradients, None to disable. Defaults to 1.
            Only the value from the first entry of ``settings`` is used.

    Examples:
        Mars-AdamW

        .. code-block:: python

            optimizer = tz.Modular(
                model.parameters(),
                tz.m.MARSCorrection(beta=0.95),
                tz.m.Adam(beta1=0.95, beta2=0.99),
                tz.m.WeightDecay(1e-3),
                tz.m.LR(0.1)
            )

        Mars-Lion

        .. code-block:: python

            optimizer = tz.Modular(
                model.parameters(),
                tz.m.MARSCorrection(beta=0.9),
                tz.m.Lion(beta1=0.9),
                tz.m.LR(0.1)
            )

    """
    def __init__(
        self,
        beta: float = 0.9,
        scaling: float = 0.025,
        max_norm: float | None = 1,
    ):
        super().__init__(dict(beta=beta, scaling=scaling, max_norm=max_norm), uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        # previous uncorrected tensors; `init=tensors` presumably seeds the state from the
        # current tensors on the first step - TODO confirm unpack_states semantics
        previous = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
        beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)

        # a single clipping threshold is taken from the first setting
        clip_value = settings[0]['max_norm']

        return mars_correction_(
            tensors_=TensorList(tensors),
            prev_=previous,
            beta=beta,
            scaling=scaling,
            max_norm=clip_value,
        )
|