torchzero-0.3.9-py3-none-any.whl → torchzero-0.3.11-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0

--- a/torchzero/modules/optimizers/soap.py
+++ b/torchzero/modules/optimizers/soap.py
@@ -2,7 +2,7 @@ from operator import itemgetter

 import torch

-from ...core import Chainable, Transform,
+from ...core import Chainable, Transform, apply_transform
 from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims

 @torch.no_grad
@@ -24,11 +24,9 @@ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
     Projects the gradient to the eigenbases of the preconditioner.
     """
     for mat in Q:
-        if mat is None:
-        if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
         else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
             permute_order = list(range(1, len(tensors.shape))) + [0]
             tensors = tensors.permute(permute_order)

@@ -40,8 +38,7 @@ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
     Projects the gradient back to the original space.
     """
     for mat in Q:
-        if mat is None:
-        if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
             tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
         else:
             permute_order = list(range(1, len(tensors.shape))) + [0]
@@ -59,8 +56,7 @@ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
     float_data = False
     original_type = original_device = None
     for m in mat:
-        if m is None:
-        if len(m) == 0:
+        if m is None or len(m) == 0:
             matrix.append([])
             continue
         if m.dtype != torch.float:
@@ -100,13 +96,11 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
     float_data = False
     original_type = original_device = None
     for m,o in zip(GG, Q_list):
-        if m is None:
-            assert o is not None
-
-        if len(m) == 0:
+        if m is None or len(m) == 0:
             matrix.append([])
             orth_matrix.append([])
             continue
+        assert o is not None
         if m.data.dtype != torch.float:
             original_type = m.data.dtype
             original_device = m.data.device
@@ -152,11 +146,28 @@ class SOAP(Transform):
             epsilon for dividing first momentum by second. Defaults to 1e-8.
         decay (float | None, optional):
             Decays covariance matrix accumulators, this may be useful if `shampoo_beta` is None. Defaults to None.
-
-
-            results but True usually works better. Defaults to True.
+        alpha (float, optional):
+            learning rate. Defaults to 1.
         bias_correction (bool, optional):
             enables adam bias correction. Defaults to True.
+
+    Examples:
+        SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))
+
+        Stabilized SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SOAP(),
+                tz.m.NormalizeByEMA(max_ema_growth=1.2),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
@@ -170,7 +181,6 @@ class SOAP(Transform):
         eps: float = 1e-8,
         decay: float | None = None,
         alpha: float = 1,
-        unprojected_exp_avg: bool = True,
         bias_correction: bool = True,
     ):
         defaults = dict(
@@ -183,21 +193,18 @@ class SOAP(Transform):
             precondition_1d=precondition_1d,
             eps=eps,
             decay=decay,
-            unprojected_exp_avg=unprojected_exp_avg,
             bias_correction=bias_correction,
             alpha=alpha,
         )
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-
-
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
+            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps,alpha = itemgetter(
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps','alpha')(setting)

             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -205,7 +212,7 @@ class SOAP(Transform):
             # initialize state on 1st step
             if 'GG' not in state:
                 state["exp_avg"] = torch.zeros_like(t)
-                state["
+                state["exp_avg_sq_projected"] = torch.zeros_like(t)

                 if not precondition_1d and t.ndim <= 1:
                     state['GG'] = []
@@ -235,35 +242,31 @@ class SOAP(Transform):
             # exponential moving averages
             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
             exp_avg: torch.Tensor = state["exp_avg"]
-
+            exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]

-
-                exp_avg.lerp_(t, 1-beta1)
-            else:
-                exp_avg.lerp_(t_projected, 1-beta1)
+            exp_avg.lerp_(t, 1-beta1)

             if t_projected is None:
-
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t, t, value=1-beta2)
             else:
-
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)

             # project exponential moving averages if they are accumulated unprojected
             exp_avg_projected = exp_avg
-            if
+            if t_projected is not None:
                 exp_avg_projected = project(exp_avg, state['Q'])

-            exp_avg_sq_projected = exp_avg_sq
-
             denom = exp_avg_sq_projected.sqrt().add_(eps)
             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')

             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
             # to the original space
             update = exp_avg_projected / denom
+
             if t_projected is not None:
                 update = project_back(update, state["Q"])

-            if
+            if setting['bias_correction']:
                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -279,7 +282,7 @@ class SOAP(Transform):
             # Update is done after the gradient step to avoid using current gradients in the projection.
             if state['GG'] is not None:
                 update_soap_covariances_(t, state['GG'], shampoo_beta)
-                if state['step'] %
-                    state['Q'], state['
+                if state['step'] % setting['precond_freq'] == 0:
+                    state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])

         return updates
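
The hunks above collapse the paired `if mat is None:` / `if len(mat) > 0:` guards into a single check, drop the `unprojected_exp_avg` option, and move SOAP onto the per-parameter `apply_tensors(tensors, params, grads, loss, states, settings)` interface, reading each parameter's settings from its own `setting` dict. For readers who want to see the eigenbasis projection these functions implement in isolation, here is a minimal, self-contained sketch of the same `tensordot`-based rotation (standalone names, not the package's internal API):

```python
import torch

def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]) -> torch.Tensor:
    """Rotate `tensors` into the eigenbases stored in Q, one factor per dimension."""
    for mat in Q:
        if mat is not None and len(mat) > 0:
            # contract the leading dimension of `tensors` with the rows of the factor
            tensors = torch.tensordot(tensors, mat, dims=([0], [0]))
        else:
            # no factor for this dimension: cycle it to the back unchanged
            permute_order = list(range(1, tensors.ndim)) + [0]
            tensors = tensors.permute(permute_order)
    return tensors

# toy usage: an eigenbasis only for the second dimension of a 2x3 tensor
t = torch.randn(2, 3)
Q = [None, torch.linalg.qr(torch.randn(3, 3)).Q]
print(project(t, Q).shape)  # torch.Size([2, 3])
```

Each factor in `Q` is contracted against the current leading dimension in turn; a `None` or empty entry simply cycles that dimension to the back, which is exactly what the consolidated guard in the diff checks for.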

--- a/torchzero/modules/optimizers/sophia_h.py
+++ b/torchzero/modules/optimizers/sophia_h.py
@@ -2,7 +2,7 @@ from typing import Literal
 from collections.abc import Callable
 import torch

-from ...core import Module, Target, Transform, Chainable,
+from ...core import Module, Target, Transform, Chainable, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central

@@ -35,6 +35,74 @@ def sophia_H(


 class SophiaH(Module):
+    """SophiaH optimizer from https://arxiv.org/abs/2305.14342
+
+    This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is agressively clipped.
+
+    .. note::
+        In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply SophiaH preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.96.
+        beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
+        update_freq (int, optional):
+            frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.
+        precond_scale (float, optional):
+            scale of the preconditioner. Defaults to 1.
+        clip (float, optional):
+            clips update to (-clip, clip). Defaults to 1.
+        eps (float, optional):
+            clips hessian diagonal esimate to be no less than this value. Defaults to 1e-12.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        Using SophiaH:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SophiaH(),
+                tz.m.LR(0.1)
+            )
+
+        SophiaH preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
+        Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
+        SophiaH preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
+                tz.m.LR(0.1)
+            )
+
+    """
     def __init__(
         self,
         beta1: float = 0.96,
@@ -56,8 +124,8 @@ class SophiaH(Module):
         self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
-        params =
+    def step(self, var):
+        params = var.params
         settings = self.settings[params[0]]
         hvp_method = settings['hvp_method']
         fd_h = settings['fd_h']
@@ -71,37 +139,26 @@ class SophiaH(Module):
                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
             generator = self.global_state['generator']

-        beta1, beta2, precond_scale, clip, eps = self.get_settings(
-            'beta1', 'beta2', 'precond_scale', 'clip', 'eps',
+        beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
+            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)

-        exp_avg, h_exp_avg = self.get_state('exp_avg', 'h_exp_avg',
+        exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)

         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

-        closure =
+        closure = var.closure
         assert closure is not None

         h = None
         if step % update_freq == 0:

-
+            rgrad=None
             for i in range(n_samples):
                 u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]

-
-
-                    assert grad is not None
-                    Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)
-
-                elif hvp_method == 'forward':
-                    loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=vars.get_grad(), normalize=True)
-
-                elif hvp_method == 'central':
-                    loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
-
-                else:
-                    raise ValueError(hvp_method)
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=fd_h, normalize=True, retain_grad=i < n_samples-1)

                 if h is None: h = Hvp
                 else: torch._foreach_add_(h, Hvp)
@@ -109,11 +166,11 @@ class SophiaH(Module):
            assert h is not None
            if n_samples > 1: torch._foreach_div_(h, n_samples)

-        update =
+        update = var.get_update()
        if 'inner' in self.children:
-            update =
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)

-
+        var.update = sophia_H(
            tensors=TensorList(update),
            h=TensorList(h) if h is not None else None,
            exp_avg_=exp_avg,
@@ -126,4 +183,4 @@ class SophiaH(Module):
            eps=eps,
            step=step,
        )
-        return
+        return var
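
The new SophiaH docstring requires a closure that accepts a `backward` argument, because the module re-evaluates the loss and gradients to form Hessian-vector products. Below is a hedged sketch of that closure pattern, built around the `tz.Modular` examples embedded in the docstring above; the model, data, and loss are placeholders, and the exact `zero_grad`/`backward` handling should be checked against the library's documentation:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(8, 1)
X, y = torch.randn(32, 8), torch.randn(32, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.SophiaH(hvp_method="central"),  # finite-difference HVPs, per the docstring note
    tz.m.LR(0.1),
)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
```

Passing `hvp_method="central"` follows the docstring's note for setups where exact autograd HVPs are unavailable; the default `"autograd"` avoids the two extra gradient evaluations per Hessian-vector product.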

--- a/torchzero/modules/projections/__init__.py
+++ b/torchzero/modules/projections/__init__.py
@@ -1,5 +1,3 @@
-from .projection import
-from .
-from .structural import VectorProjection, TensorizeProjection, BlockPartition, TensorNormsProjection
-
+from .projection import ProjectionBase, VectorProjection, ScalarProjection
+from .cast import To, ViewAsReal
 # from .galore import GaLore

--- /dev/null
+++ b/torchzero/modules/projections/cast.py
@@ -0,0 +1,51 @@
+import torch
+from .projection import ProjectionBase
+from ...core import Chainable
+
+class To(ProjectionBase):
+    """Cast modules to specified device and dtype"""
+    def __init__(self, modules: Chainable, dtype: torch.dtype | None, device:torch.types.Device | None = None):
+        defaults = dict(dtype=dtype, device=device)
+        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=defaults)
+
+    @torch.no_grad
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        casted = []
+        for tensor, state, setting in zip(tensors,states, settings):
+            state['dtype'] = tensor.dtype
+            state['device'] = tensor.device
+            tensor = tensor.to(dtype=setting['dtype'], device=setting['device'])
+            casted.append(tensor)
+        return casted
+
+    @torch.no_grad
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        uncasted = []
+        for tensor, state in zip(projected_tensors, states):
+            tensor = tensor.to(dtype=state['dtype'], device=state['device'])
+            uncasted.append(tensor)
+        return uncasted
+
+
+class ViewAsReal(ProjectionBase):
+    """View complex tensors as real tensors. Doesn't affect tensors that are already."""
+    def __init__(self, modules: Chainable):
+        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=None)
+
+    @torch.no_grad
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        views = []
+        for tensor, state in zip(tensors,states):
+            is_complex = torch.is_complex(tensor)
+            state['is_complex'] = is_complex
+            if is_complex: tensor = torch.view_as_real(tensor)
+            views.append(tensor)
+        return views
+
+    @torch.no_grad
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        un_views = []
+        for tensor, state in zip(projected_tensors, states):
+            if state['is_complex']: tensor = torch.view_as_complex(tensor)
+            un_views.append(tensor)
+        return un_views
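
The new `cast.py` adds two wrapper projections: `To` records each tensor's dtype and device, casts it to the configured ones before the wrapped modules run, and restores the original dtype/device in `unproject`; `ViewAsReal` flags complex tensors and views them as real so that modules written for real tensors can process them. A hedged usage sketch, assuming both classes are re-exported under `tz.m` like the modules in the docstring examples above (the models and the choice of `tz.m.Adam` as the wrapped module are placeholders):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(16, 4)

# run the wrapped module's statistics in bfloat16, casting the update back afterwards
opt_bf16 = tz.Modular(
    model.parameters(),
    tz.m.To(tz.m.Adam(), dtype=torch.bfloat16),
    tz.m.LR(1e-3),
)

# view complex parameters as real tensors for modules that assume real inputs
complex_model = torch.nn.Linear(16, 4).to(torch.complex64)
opt_complex = tz.Modular(
    complex_model.parameters(),
    tz.m.ViewAsReal(tz.m.Adam()),
    tz.m.LR(1e-3),
)
```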