torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/{optimizers → adaptive}/adahessian.py

@@ -1,38 +1,42 @@
 import math
-from collections.abc import Callable
 from typing import Literal

 import torch

 from ...core import Chainable, Module, Target, Transform, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
-from
+from ..functional import debiased_step_size

+def _full_average(hvp: torch.Tensor):
+    if hvp.ndim >= 3: # Conv kernel
+        return torch.mean(hvp.abs(), dim=[2, *range(3,hvp.ndim)], keepdim=True)
+    return hvp

 def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
     """averages x over first dimension in blocks"""
     if enable and x.ndim >= 2:
         if math.prod(x.shape[1:]) <= 1: return x
+        if block_size is None: return _full_average(x)
         size = x.size(0)
-        if block_size is None: return x.mean(0, keepdim=True)

         n_blocks = size // block_size
-        if n_blocks <= 1: return x.mean(0, keepdim = True)
+        if n_blocks <= 1: return x.abs().mean(0, keepdim = True)

         n_remaining = size - n_blocks * block_size
         remaining = None
         if n_remaining > 0:
-            remaining = x[-n_remaining:].mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
+            remaining = x[-n_remaining:].abs().mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
             x = x[:-n_remaining]

         x = x.view(block_size, n_blocks, *x.shape[1:])
-        x_mean = x.mean(0).repeat_interleave(block_size, 0)
+        x_mean = x.abs().mean(0).repeat_interleave(block_size, 0)

         if remaining is None: return x_mean
         return torch.cat([x_mean, remaining], 0)

     return x

+
 def _rademacher_like(tensor, p = 0.5, generator = None):
     """p is probability of a 1, other values will be -1."""
     return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)

@@ -46,11 +50,11 @@ def adahessian(
     beta2: float | NumberList,
     update_freq: int,
     eps: float | NumberList,
+    hessian_power: float | NumberList,
     step: int,
 ):
     # momentum
     exp_avg_.lerp_(tensors, 1-beta1)
-    num = exp_avg_ / (1-beta1)

     # update preconditioner
     if step % update_freq == 0:

@@ -60,7 +64,9 @@ def adahessian(
     else:
         assert D is None

-
+
+    denom = D_exp_avg_sq_.sqrt().pow_(hessian_power).add_(eps)
+    num = exp_avg_ * debiased_step_size(step+1, beta1, beta2)

     return num.div_(denom)


@@ -70,16 +76,12 @@ class AdaHessian(Module):

     This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.

-
-    In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the
+    Notes:
+        - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.

-
-    If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+        - If you are using gradient estimators or reformulations, set ``hvp_method`` to "forward" or "central".

-
-    This module requires a closure passed to the optimizer step,
-    as it needs to re-evaluate the loss and gradients for calculating HVPs.
-    The closure must accept a ``backward`` argument (refer to documentation).
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

     Args:
         beta1 (float, optional): first momentum. Defaults to 0.9.

@@ -105,7 +107,7 @@ class AdaHessian(Module):
             more accurate HVP approximation. This requires two extra
             gradient evaluations.
             Defaults to "autograd".
-
+        fd_h (float, optional): finite difference step size if ``hvp_method`` is "forward" or "central". Defaults to 1e-3.
         n_samples (int, optional):
             number of hessian-vector products with random vectors to evaluate each time when updating
             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.

@@ -113,48 +115,49 @@ class AdaHessian(Module):
         inner (Chainable | None, optional):
             Inner module. If this is specified, operations are performed in the following order.
             1. compute hessian diagonal estimate.
-            2. pass inputs to
-            3. momentum and preconditioning are applied to the ouputs of
-
-    Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            2. pass inputs to ``inner``.
+            3. momentum and preconditioning are applied to the ouputs of ``inner``.
+
+    ## Examples:
+
+    Using AdaHessian:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.AdaHessian(),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    AdaHessian preconditioner can be applied to any other module by passing it to the ``inner`` argument.
+    Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
+    AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```

     """
     def __init__(
         self,
         beta1: float = 0.9,
         beta2: float = 0.999,
-        averaging: bool =
-        block_size: int | None =
+        averaging: bool = True,
+        block_size: int | None = None,
         update_freq: int = 1,
         eps: float = 1e-8,
+        hessian_power: float = 1,
         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
         fd_h: float = 1e-3,
         n_samples = 1,
         seed: int | None = None,
         inner: Chainable | None = None
     ):
-        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hessian_power=hessian_power, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
         super().__init__(defaults)

         if inner is not None:

@@ -170,14 +173,10 @@ class AdaHessian(Module):
         n_samples = settings['n_samples']

         seed = settings['seed']
-        generator =
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
+        generator = self.get_generator(params[0].device, seed)

-        beta1, beta2, eps, averaging, block_size = self.get_settings(params,
-            'beta1', 'beta2', 'eps', 'averaging', 'block_size', cls=NumberList)
+        beta1, beta2, eps, averaging, block_size, hessian_power = self.get_settings(params,
+            'beta1', 'beta2', 'eps', 'averaging', 'block_size', "hessian_power", cls=NumberList)

         exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)


@@ -196,6 +195,7 @@ class AdaHessian(Module):

                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                       h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+                Hvp = tuple(Hvp)

                 if D is None: D = Hvp
                 else: torch._foreach_add_(D, Hvp)

@@ -218,6 +218,7 @@ class AdaHessian(Module):
             beta2=beta2,
             update_freq=update_freq,
             eps=eps,
+            hessian_power=hessian_power,
             step=step,
         )
         return var
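The hunks above add a ``hessian_power`` parameter (default 1) to AdaHessian and switch the update to ``debiased_step_size``. Below is a hedged sketch of exercising the new parameter; the toy model, data, value 0.5, and closure names are illustrative assumptions, with the closure following the ``backward``-argument convention the docstring refers to.

```python
# Hedged sketch, not an official example: exercises the hessian_power parameter
# added in 0.3.14. The toy model/data and the value 0.5 are purely illustrative.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)                  # toy model (assumption)
X, y = torch.randn(64, 10), torch.randn(64, 1)  # toy data (assumption)

opt = tz.Modular(
    model.parameters(),
    tz.m.AdaHessian(hessian_power=0.5),  # default is 1, which matches plain AdaHessian
    tz.m.LR(0.1),
)

# AdaHessian re-evaluates the loss/gradients for hessian-vector products,
# so the closure must accept a ``backward`` argument (per the docstring above).
def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```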
torchzero/modules/{optimizers → adaptive}/adan.py

@@ -9,37 +9,38 @@ def adan_(
     m_: TensorList, # exponential moving average
     v_: TensorList, # exponential moving average of gradient differences
     n_: TensorList, # kinda like squared momentum
-    n_prev_: TensorList | None,
     beta1: float | NumberList,
     beta2: float | NumberList,
     beta3: float | NumberList,
     eps: float | NumberList,
-
+    step: int,
 ):
-    """Returns new tensors
-    m_.lerp_(g, 1-beta1)
+    """Returns new tensors"""
+    m_.lerp_(g, 1 - beta1)

-
-
+    if step == 1:
+        term = g
+    else:
+        diff = g - g_prev_
+        v_.lerp_(diff, 1 - beta2)
+        term = g + beta2 * diff

-
-    n_.mul_(beta3).addcmul_(y, y, 1-beta3)
+    n_.mul_(beta3).addcmul_(term, term, value=(1 - beta3))

-
-
-
-    n_prev_.copy_(n_)
-    n_ = ns
+    m = m_ / (1.0 - beta1**step)
+    v = v_ / (1.0 - beta2**step)
+    n = n_ / (1.0 - beta3**step)

-
-
-    update = eta.mul_(term)
+    denom = n.sqrt_().add_(eps)
+    num = m + beta2 * v

+    update = num.div_(denom)
     g_prev_.copy_(g)

     return update


+
 class Adan(Transform):
     """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677


@@ -51,6 +52,13 @@ class Adan(Transform):
         use_n_prev (bool, optional):
             whether to use previous gradient differences momentum.

+    Example:
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adan(),
+        tz.m.LR(1e-3),
+    )
     Reference:
         Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
     """

@@ -60,9 +68,8 @@ class Adan(Transform):
         beta2: float = 0.92,
         beta3: float = 0.99,
         eps: float = 1e-8,
-        use_n_prev: bool = False,
     ):
-        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps)
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad

@@ -71,40 +78,19 @@ class Adan(Transform):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

         beta1,beta2,beta3,eps=unpack_dicts(settings, 'beta1','beta2','beta3','eps', cls=NumberList)
-        s = settings[0]
-        use_n_prev = s['use_n_prev']
-
         g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)

-
-        if use_n_prev:
-            n_prev = unpack_states(states, tensors, 'n_prev', cls=TensorList)
-        else:
-            n_prev = None
-
-        if step == 1:
-            # initial values, also runs on restarts
-            m.copy_(tensors)
-            n.set_(tensors ** 2)
-            v.zero_()
-            g_prev.copy_(tensors)
-            if n_prev is not None: n_prev.set_(tensors ** 2)
-
-        if step == 2:
-            v.set_(tensors - g_prev)
-
         update = adan_(
             g=tensors,
             g_prev_=g_prev,
             m_=m,
             v_=v,
             n_=n,
-            n_prev_=n_prev,
             beta1=beta1,
             beta2=beta2,
             beta3=beta3,
             eps=eps,
-
+            step=step,
         )

         return update
torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py

@@ -4,7 +4,7 @@ from ...utils import TensorList, unpack_dicts, unpack_states


 def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p: TensorList, p_prev: TensorList):
-    if f - f_star <= torch.finfo(p[0].dtype).
+    if f - f_star <= torch.finfo(p[0].dtype).tiny * 2: return g

     g_g = g.dot(g)
     g_gp = g.dot(g_prev)

@@ -21,14 +21,12 @@ class AdaptiveHeavyBall(Transform):

     This is related to conjugate gradient methods, it may be very good for non-stochastic convex objectives, but won't work on stochastic ones.

-
+    note:
         The step size is determined by the algorithm, so learning rate modules shouldn't be used.

     Args:
         f_star (int, optional):
             (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-        tol (float, optional):
-            tolerance on objective value change.
     """
     def __init__(self, f_star: float = 0):
         defaults = dict(f_star=f_star)

@@ -38,8 +36,7 @@ class AdaptiveHeavyBall(Transform):
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert loss is not None
         tensors = TensorList(tensors)
-
-        f_star = setting['f_star']
+        f_star = self.defaults['f_star']

         f_prev = self.global_state.get('f_prev', None)
         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)
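Since the hunks above note that AdaptiveHeavyBall determines its own step size (so LR modules shouldn't be chained) and that it asserts a loss is available, here is a hedged usage sketch; it assumes the module is exported as ``tz.m.AdaptiveHeavyBall`` and uses the standard closure convention, with a toy model for illustration.

```python
# Hedged sketch, assuming the module is exported as tz.m.AdaptiveHeavyBall.
# No tz.m.LR is chained: the step size is chosen by the algorithm itself.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)                  # toy model (assumption)
X, y = torch.randn(64, 10), torch.randn(64, 1)  # toy data (assumption)

opt = tz.Modular(
    model.parameters(),
    tz.m.AdaptiveHeavyBall(f_star=0),  # f_star: (estimated) minimal possible loss
)

# apply_tensors asserts that a loss is available, so a closure is required.
def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```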
torchzero/modules/adaptive/aegd.py

@@ -0,0 +1,54 @@
+import math
+
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+# i've verified, it is identical to official
+# https://github.com/txping/AEGD/blob/master/aegd.py
+def aegd_(f: torch.Tensor | float, g: TensorList, r_: TensorList, c:float|NumberList=1, eta:float|NumberList=0.1) -> TensorList:
+    v = g / (2 * (f + c)**0.5)
+    r_ /= 1 + (v ** 2).mul_(2*eta) # update energy
+    return 2*eta * r_*v # pyright:ignore[reportReturnType]
+
+class AEGD(Transform):
+    """AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.
+
+    Note:
+        AEGD has a learning rate hyperparameter that can't really be removed from the update rule.
+        To avoid compounding learning rate mofications, remove the ``tz.m.LR`` module if you had it.
+
+    Args:
+        eta (float, optional): step size. Defaults to 0.1.
+        c (float, optional): c. Defaults to 1.
+        beta3 (float, optional): thrid (squared) momentum. Defaults to 0.1.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        use_n_prev (bool, optional):
+            whether to use previous gradient differences momentum.
+    """
+    def __init__(
+        self,
+        lr: float = 0.1,
+        c: float = 1,
+    ):
+        defaults=dict(c=c,lr=lr)
+        super().__init__(defaults, uses_loss=True)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        tensors = TensorList(tensors)
+
+        c,lr=unpack_dicts(settings, 'c','lr', cls=NumberList)
+        r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss+c[0])**0.5), cls=TensorList)
+
+        update = aegd_(
+            f=loss,
+            g=tensors,
+            r_=r,
+            c=c,
+            eta=lr,
+        )
+
+        return update
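The new ``aegd.py`` above ships without a usage example, and its note warns against stacking ``tz.m.LR`` on top since the ``lr``/``eta`` factor is part of the update rule itself. A hedged sketch follows; it assumes AEGD is exported as ``tz.m.AEGD`` like the other adaptive modules, and the toy model and closure are illustrative.

```python
# Hedged sketch, assuming the new module is exported as tz.m.AEGD.
# Per the note above, no tz.m.LR is chained: lr/eta is baked into AEGD's update rule.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)                  # toy model (assumption)
X, y = torch.randn(64, 10), torch.randn(64, 1)  # toy data (assumption)

opt = tz.Modular(
    model.parameters(),
    tz.m.AEGD(lr=0.1, c=1),  # defaults written out explicitly
)

# AEGD is constructed with uses_loss=True, so the loss must reach it via a closure.
def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```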
torchzero/modules/{optimizers → adaptive}/esgd.py

@@ -61,7 +61,7 @@ class ESGD(Module):
             more accurate HVP approximation. This requires two extra
             gradient evaluations.
             Defaults to "autograd".
-
+        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
         n_samples (int, optional):
             number of hessian-vector products with random vectors to evaluate each time when updating
             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py}

@@ -5,8 +5,12 @@ import warnings
 import torch
 from ...core import Chainable, TensorwiseTransform

-def lm_adagrad_update(history: deque[torch.Tensor], damping, rdamping):
-
+def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping):
+    if isinstance(history, torch.Tensor):
+        M = history
+    else:
+        M = torch.stack(tuple(history), dim=1)# / len(history)
+
     MTM = M.T @ M
     if damping != 0:
         MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))

@@ -58,47 +62,45 @@ class LMAdagrad(TensorwiseTransform):
             order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
         true_damping (bool, optional):
             If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
-        eigh (bool, optional): uses a more efficient way to calculate U and S. Defaults to True.
         U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
-
+        L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
         interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
         concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
         inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.

-    Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ## Examples:
+
+    Limited-memory Adagrad
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(),
+        tz.m.LR(0.1)
+    )
+    ```
+    Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(0.01)
+    )
+    ```
+
+    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.ClipNormByEMA(max_ema_growth=1.2),
+        tz.m.LR(0.01)
+    )
+    ```
     Reference:
         Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – С. 102-110.
     """

@@ -143,6 +145,7 @@ class LMAdagrad(TensorwiseTransform):
             # scaled by parameter differences
             cur_p = param.clone()
             cur_g = tensor.clone()
+            eps = torch.finfo(cur_p.dtype).tiny * 2
             for i in range(1, order):
                 if f'prev_g_{i}' not in state:
                     state[f'prev_p_{i}'] = cur_p

@@ -157,7 +160,7 @@ class LMAdagrad(TensorwiseTransform):
                     cur_g = y

                 if i == order - 1:
-                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=
+                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=eps) # pylint:disable=not-callable
             history.append(cur_g.view(-1))

             step = state.get('step', 0)
torchzero/modules/{optimizers → adaptive}/mars.py

@@ -1,18 +1,7 @@
-from operator import itemgetter
-from functools import partial
-
 import torch

-from ...core import
+from ...core import Transform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..functional import (
-    debias, debiased_step_size,
-    ema_,
-    sqrt_ema_sq_,
-)
-from ..step_size.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_


 def mars_correction_(

@@ -35,36 +24,35 @@ class MARSCorrection(Transform):
     """MARS variance reduction correction.

     Place any other momentum-based optimizer after this,
-    make sure
+    make sure ``beta`` parameter matches with momentum in the optimizer.

     Args:
         beta (float, optional): use the same beta as you use in the momentum module. Defaults to 0.9.
         scaling (float, optional): controls the scale of gradient correction in variance reduction. Defaults to 0.025.
         max_norm (float, optional): clips norm of corrected gradients, None to disable. Defaults to 1.

-    Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
+    ## Examples:
+
+    Mars-AdamW
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.MARSCorrection(beta=0.95),
+        tz.m.Adam(beta1=0.95, beta2=0.99),
+        tz.m.WeightDecay(1e-3),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    Mars-Lion
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.MARSCorrection(beta=0.9),
+        tz.m.Lion(beta1=0.9),
+        tz.m.LR(0.1)
+    )
+    ```

     """
     def __init__(