torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/adahessian.py (new file)

@@ -0,0 +1,223 @@
+import math
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+
+def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
+    """averages x over first dimension in blocks"""
+    if enable and x.ndim >= 2:
+        if math.prod(x.shape[1:]) <= 1: return x
+        size = x.size(0)
+        if block_size is None: return x.mean(0, keepdim=True)
+
+        n_blocks = size // block_size
+        if n_blocks <= 1: return x.mean(0, keepdim = True)
+
+        n_remaining = size - n_blocks * block_size
+        remaining = None
+        if n_remaining > 0:
+            remaining = x[-n_remaining:].mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
+            x = x[:-n_remaining]
+
+        x = x.view(block_size, n_blocks, *x.shape[1:])
+        x_mean = x.mean(0).repeat_interleave(block_size, 0)
+
+        if remaining is None: return x_mean
+        return torch.cat([x_mean, remaining], 0)
+
+    return x
+
+def _rademacher_like(tensor, p = 0.5, generator = None):
+    """p is probability of a 1, other values will be -1."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)
+
+def adahessian(
+    tensors: TensorList,
+    D: TensorList | None,
+    exp_avg_: TensorList,
+    D_exp_avg_sq_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    update_freq: int,
+    eps: float | NumberList,
+    step: int,
+):
+    # momentum
+    exp_avg_.lerp_(tensors, 1-beta1)
+    num = exp_avg_ / (1-beta1)
+
+    # update preconditioner
+    if step % update_freq == 0:
+        assert D is not None
+        D_exp_avg_sq_.mul_(beta2).addcmul_(D, D, 1-beta2)
+
+    else:
+        assert D is None
+
+    denom = (D_exp_avg_sq_ / (1-beta2)).sqrt_().add_(eps)
+
+    return num.div_(denom)
+
+
+class AdaHessian(Module):
+    """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)
+
+    This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
+
+    .. note::
+        In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply AdaHessian preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
+        averaging (bool, optional):
+            whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
+            This can be set per-parameter in param groups.
+        block_size (int, optional):
+            size of block in the block-diagonal averaging.
+        update_freq (int, optional):
+            frequency of updating hessian diagonal estimate via a hessian-vector product.
+            This value can be increased to reduce computational cost. Defaults to 1.
+        eps (float, optional):
+            division stability epsilon. Defaults to 1e-8.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner module. If this is specified, operations are performed in the following order.
+            1. compute hessian diagonal estimate.
+            2. pass inputs to :code:`inner`.
+            3. momentum and preconditioning are applied to the ouputs of :code:`inner`.
+
+    Examples:
+        Using AdaHessian:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.AdaHessian(),
+                tz.m.LR(0.1)
+            )
+
+        AdaHessian preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
+        Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
+        AdaHessian preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
+                tz.m.LR(0.1)
+            )
+
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        averaging: bool = False,
+        block_size: int | None = 9,
+        update_freq: int = 1,
+        eps: float = 1e-8,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+        beta1, beta2, eps, averaging, block_size = self.get_settings(params,
+            'beta1', 'beta2', 'eps', 'averaging', 'block_size', cls=NumberList)
+
+        exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = var.closure
+        assert closure is not None
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad=None
+            for i in range(n_samples):
+                u = [_rademacher_like(p, generator=generator) for p in params]
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                    h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            assert D is not None
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
+
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+        var.update = adahessian(
+            tensors=TensorList(update),
+            D=TensorList(D) if D is not None else None,
+            exp_avg_=exp_avg,
+            D_exp_avg_sq_=D_exp_avg_sq,
+            beta1=beta1,
+            beta2=beta2,
+            update_freq=update_freq,
+            eps=eps,
+            step=step,
+        )
+        return var
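The AdaHessian docstring requires a closure that accepts a ``backward`` argument but does not show one. Below is a minimal sketch of a training step driving this module, assuming the standard torchzero closure convention; ``model``, ``loss_fn``, ``x`` and ``y`` are placeholders, and the exact closure semantics are defined by the torchzero documentation rather than by this diff.

```python
import torchzero as tz

# assumed chain, following the usage shown in the AdaHessian docstring
opt = tz.Modular(
    model.parameters(),   # placeholder model
    tz.m.AdaHessian(),    # Hessian-diagonal preconditioning via autograd HVPs
    tz.m.LR(0.1),
)

def closure(backward=True):
    # the module re-evaluates loss/gradients for HVPs, hence the `backward` flag
    loss = loss_fn(model(x), y)   # placeholder objective
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```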
torchzero/modules/optimizers/adam.py

@@ -3,14 +3,14 @@ from functools import partial
 
 import torch
 
-from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import (
     debias, debiased_step_size,
     ema_,
     sqrt_ema_sq_,
 )
-from ..
+from ..step_size.lr import lazy_lr
 from ..momentum.experimental import sqrt_nag_ema_sq_
 from ..momentum.momentum import nag_
 
@@ -27,26 +27,28 @@ def adam_(
     pow: float = 2,
     debiased: bool = True,
     max_exp_avg_sq_: TensorList | None = None,
-    params_: TensorList | None = None,
-):
-    """Returns new tensors or updates params in-place."""
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
 
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """Returns new tensors."""
     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
         debiased=False,step=step,pow=pow)
 
-    if
+    if inner is not None:
+        assert params is not None
+        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
 
-
-    if
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+    return (exp_avg_.lazy_mul(alpha) / sqrt_exp_avg_sq.add_(eps))
 
-
-
-    return None
+class Adam(Transform):
+    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
 
-
-    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size. This implementation is slightly different from
-    pytorch in that debiasing is applied after adding epsilon.
+    This implementation is identical to :code:`torch.optim.Adam`.
 
     Args:
         beta1 (float, optional): momentum. Defaults to 0.9.
@@ -66,36 +68,29 @@ class Adam(Module):
         alpha: float = 1.,
         pow: float = 2,
         debiased: bool = True,
+        inner: Chainable | None = None
     ):
         defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-        super().__init__(defaults)
-
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None: self.set_child('inner', inner)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        beta1,beta2,eps,alpha=
-        amsgrad,pow,debiased =
+        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq =
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq =
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
-        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if vars.is_last:
-            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
-            passed_params = TensorList(vars.params)
-            vars.stop = True
-            vars.skip_update = True
-
-        else:
-            passed_params = None
 
-
-        tensors=TensorList(
+        return adam_(
+            tensors=TensorList(tensors),
             exp_avg_=exp_avg,
             exp_avg_sq_=exp_avg_sq,
             alpha=alpha,
@@ -106,7 +101,10 @@ class Adam(Module):
             pow=pow,
             debiased=debiased,
             max_exp_avg_sq_=max_exp_avg_sq,
-            params_=passed_params,
-        )
 
-
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+
+        )
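The rewritten ``adam_`` also fixes the order in which the new ``inner`` hook sees data: the second-moment EMA is accumulated from the incoming tensors, the inner module is then applied via ``apply_transform``, and the first-moment EMA plus debiased step size act on its output. A hedged sketch of how that hook might be used, reusing the ``tz.Modular`` pattern from the other docstrings in this diff (the NAG composition is illustrative, not taken from adam.py):

```python
import torchzero as tz

# second moments from the raw update, momentum/step size applied to NAG's output
opt = tz.Modular(
    model.parameters(),              # placeholder model
    tz.m.Adam(inner=tz.m.NAG(0.9)),  # illustrative use of the new `inner` argument
    tz.m.LR(1e-3),
)
```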
torchzero/modules/optimizers/adan.py (new file)

@@ -0,0 +1,110 @@
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+def adan_(
+    g: TensorList,
+    g_prev_: TensorList,
+    m_: TensorList, # exponential moving average
+    v_: TensorList, # exponential moving average of gradient differences
+    n_: TensorList, # kinda like squared momentum
+    n_prev_: TensorList | None,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    use_n_prev: bool,
+):
+    """Returns new tensors."""
+    m_.lerp_(g, 1-beta1)
+
+    y = g - g_prev_
+    v_.lerp_(y, 1-beta2)
+
+    y.mul_(1-beta2).add_(g)
+    n_.mul_(beta3).addcmul_(y, y, 1-beta3)
+
+    if use_n_prev:
+        assert n_prev_ is not None
+        ns = n_prev_.clone()
+        n_prev_.copy_(n_)
+        n_ = ns
+
+    eta = n_.sqrt().add_(eps).reciprocal_()
+    term = m_ + (1-beta2)*v_
+    update = eta.mul_(term)
+
+    g_prev_.copy_(g)
+
+    return update
+
+
+class Adan(Transform):
+    """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.98.
+        beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
+        beta3 (float, optional): thrid (squared) momentum. Defaults to 0.99.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        use_n_prev (bool, optional):
+            whether to use previous gradient differences momentum.
+
+    Reference:
+        Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
+    """
+    def __init__(
+        self,
+        beta1: float = 0.98,
+        beta2: float = 0.92,
+        beta3: float = 0.99,
+        eps: float = 1e-8,
+        use_n_prev: bool = False,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,use_n_prev=use_n_prev)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,beta3,eps=unpack_dicts(settings, 'beta1','beta2','beta3','eps', cls=NumberList)
+        s = settings[0]
+        use_n_prev = s['use_n_prev']
+
+        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)
+
+
+        if use_n_prev:
+            n_prev = unpack_states(states, tensors, 'n_prev', cls=TensorList)
+        else:
+            n_prev = None
+
+        if step == 1:
+            # initial values, also runs on restarts
+            m.copy_(tensors)
+            n.set_(tensors ** 2)
+            v.zero_()
+            g_prev.copy_(tensors)
+            if n_prev is not None: n_prev.set_(tensors ** 2)
+
+        if step == 2:
+            v.set_(tensors - g_prev)
+
+        update = adan_(
+            g=tensors,
+            g_prev_=g_prev,
+            m_=m,
+            v_=v,
+            n_=n,
+            n_prev_=n_prev,
+            beta1=beta1,
+            beta2=beta2,
+            beta3=beta3,
+            eps=eps,
+            use_n_prev=use_n_prev,
+        )
+
+        return update
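Adan is registered as a plain gradient transform (``uses_grad=False``, no closure or hessian-vector products), so it presumably slots into the same modular chain as the other optimizers in this diff. A minimal sketch, assuming the ``tz.Modular`` / ``tz.m`` / ``tz.m.LR`` pattern shown in the AdaHessian and ESGD docstrings; ``model`` is a placeholder:

```python
import torchzero as tz

opt = tz.Modular(
    model.parameters(),   # placeholder model
    tz.m.Adan(),          # defaults from above: beta1=0.98, beta2=0.92, beta3=0.99
    tz.m.LR(1e-3),
)
```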
torchzero/modules/optimizers/adaptive_heavyball.py (new file)

@@ -0,0 +1,57 @@
+import torch
+from ...core import Transform
+from ...utils import TensorList, unpack_dicts, unpack_states
+
+
+def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p: TensorList, p_prev: TensorList):
+    if f - f_star <= torch.finfo(p[0].dtype).eps: return g
+
+    g_g = g.dot(g)
+    g_gp = g.dot(g_prev)
+    num = -(f - f_star) * g.dot(g_prev)
+    denom = (f_prev - f_star) * g_g + (f - f_star) * g_gp
+    m = num/denom
+
+    h = 2*(f - f_star) / g_g
+    return (1 + m) * h * g - m*(p-p_prev)
+
+
+class AdaptiveHeavyBall(Transform):
+    """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.
+
+    This is related to conjugate gradient methods, it may be very good for non-stochastic convex objectives, but won't work on stochastic ones.
+
+    .. note::
+        The step size is determined by the algorithm, so learning rate modules shouldn't be used.
+
+    Args:
+        f_star (int, optional):
+            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+        tol (float, optional):
+            tolerance on objective value change.
+    """
+    def __init__(self, f_star: float = 0):
+        defaults = dict(f_star=f_star)
+        super().__init__(defaults, uses_grad=False, uses_loss=True)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        tensors = TensorList(tensors)
+        setting = settings[0]
+        f_star = setting['f_star']
+
+        f_prev = self.global_state.get('f_prev', None)
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)
+
+        if f_prev is None:
+            self.global_state['f_prev'] = loss
+            h = 2*(loss - f_star) / tensors.dot(tensors)
+            return h * tensors
+
+        update = adaptive_heavy_ball(f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)
+
+        self.global_state['f_prev'] = loss
+        p_prev.copy_(params)
+        g_prev.copy_(tensors)
+        return update
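AdaptiveHeavyBall needs the loss value (``uses_loss=True``) and its docstring warns against learning-rate modules, so a usage sketch would omit ``tz.m.LR`` and supply a loss-returning closure. The chain below is an assumption based on the patterns elsewhere in this diff, not an example taken from the package; ``model``, ``loss_fn``, ``x`` and ``y`` are placeholders:

```python
import torchzero as tz

opt = tz.Modular(
    model.parameters(),                  # placeholder model
    tz.m.AdaptiveHeavyBall(f_star=0.0),  # no LR module: the method sets its own step size
)

def closure(backward=True):
    loss = loss_fn(model(x), y)          # placeholder deterministic, convex objective
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```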
torchzero/modules/optimizers/esgd.py (new file)

@@ -0,0 +1,171 @@
+import math
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist
+
+
+def esgd_(
+    tensors_: TensorList,
+    D: TensorList | None,
+    D_sq_acc_: TensorList,
+    damping: float | NumberList,
+    update_freq: int,
+    step: int,
+    i: int,
+):
+    # update preconditioner
+    if step % update_freq == 0:
+        assert D is not None
+        D_sq_acc_.addcmul_(D, D)
+        i += 1
+    else:
+        assert D is None
+
+    denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
+    return tensors_.div_(denom), i
+
+
+class ESGD(Module):
+    """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)
+
+    This is similar to Adagrad, but the accumulates squared randomized hessian diagonal estimates instead of squared gradients.
+
+    .. note::
+        In most cases Adagrad should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Adagrad preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        damping (float, optional): added to denominator for stability. Defaults to 1e-4.
+        update_freq (int, optional):
+            frequency of updating hessian diagonal estimate via a hessian-vector product.
+            This value can be increased to reduce computational cost. Defaults to 20.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner module. If this is specified, operations are performed in the following order.
+            1. compute hessian diagonal estimate.
+            2. pass inputs to :code:`inner`.
+            3. momentum and preconditioning are applied to the ouputs of :code:`inner`.
+
+    Examples:
+        Using ESGD:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ESGD(),
+                tz.m.LR(0.1)
+            )
+
+        ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
+        ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
+                tz.m.LR(0.1)
+            )
+
+    """
+    def __init__(
+        self,
+        damping: float = 1e-4,
+        update_freq: int = 20,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+        damping = self.get_settings(params, 'damping', cls=NumberList)
+        D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
+        i = self.global_state.get('i', 0)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = var.closure
+        assert closure is not None
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad=None
+            for j in range(n_samples):
+                u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                    h=fd_h, normalize=True, retain_grad=j < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            assert D is not None
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+            D = TensorList(D)
+
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+        var.update, self.global_state['i'] = esgd_(
+            tensors_=TensorList(update),
+            D=TensorList(D) if D is not None else None,
+            D_sq_acc_=D_sq_acc,
+            damping=damping,
+            update_freq=update_freq,
+            step=step,
+            i=i,
+        )
+        return var
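As with AdaHessian, ESGD needs a closure accepting ``backward`` so it can evaluate hessian-vector products, and the note about gradient estimators suggests switching ``hvp_method`` to a finite-difference mode in that case. A minimal sketch under the same placeholder assumptions as the AdaHessian example earlier (``model``, ``loss_fn``, ``x``, ``y`` are placeholders):

```python
import torchzero as tz

opt = tz.Modular(
    model.parameters(),                               # placeholder model
    tz.m.ESGD(update_freq=20, hvp_method="forward"),  # FD HVPs: one extra gradient evaluation
    tz.m.LR(0.1),
)

def closure(backward=True):
    loss = loss_fn(model(x), y)   # placeholder objective
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```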