torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1390 @@
|
|
|
1
|
+
# pylint:disable=not-callable
|
|
2
|
+
# this file is from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py ver from Sept., 2025
|
|
3
|
+
# with few minor modifications (like passing Q balancing probability)
|
|
4
|
+
"""
|
|
5
|
+
The new PSGD-Kron Newton/Whitening preconditioners support five kinds of local coordinates for updating Q:
|
|
6
|
+
|
|
7
|
+
QUAD): It's a specific form for updating Q to ensure that Q > 0 (thus Q is symmetric/Hermitian).
|
|
8
|
+
This is one of the recommended choices for fitting Q.
|
|
9
|
+
|
|
10
|
+
QEQ): dQ = Q * mathcal{E} * Q
|
|
11
|
+
This leads to another simple way for updating Q (Q is in the general linear group).
|
|
12
|
+
It's another recommended choice for fitting Q.
|
|
13
|
+
|
|
14
|
+
Q0p5EQ1p5): dQ = Q^0.5 * mathcal{E} * Q^1.5
|
|
15
|
+
One more recommended choice for fitting Q.
|
|
16
|
+
An online orthogonal Procrustes problem solver is used to keep Q approximately SPD.
|
|
17
|
+
|
|
18
|
+
EQ): dQ = mathcal{E} * Q
|
|
19
|
+
This choice recovers the old PSGD way for updating Q in Lie groups (Q is triangular).
|
|
20
|
+
Its main drawback is that triangular solvers are required for updating Q.
|
|
21
|
+
|
|
22
|
+
QEP): dQ = Q * mathcal{E} * P
|
|
23
|
+
This last choice works very well if it does. Q is in the general linear group.
|
|
24
|
+
But, one drawback is that Q might get stuck around ill-conditioned matrices (not strongly convex).
|
|
25
|
+
|
|
26
|
+
The QUAD formulae can be used to update P directly (see older commit 0fc33cd).
|
|
27
|
+
I call this choice QUAD4P. It still is a good choice for optimization with single precision.
|
|
28
|
+
Unlike QUAD, QUAD4P does not work well with half precision. Use it with caution.
|
|
29
|
+
|
|
30
|
+
The PSGD-LRA Newton/Whitening preconditioners still adopt local coordinate dQ = mathcal{E} * Q,
|
|
31
|
+
and need a small linear solver to update the preconditioner.
|
|
32
|
+
|
|
33
|
+
I also keep the PSGD dense matrix Newton-type preconditioner here to illustrate the math.
|
|
34
|
+
It supports all the five methods for updating Q,
|
|
35
|
+
and can be a good alternative to the BFGS like quasi-Newton optimizers as no line search is required.
|
|
36
|
+
|
|
37
|
+
Xi-Lin Li, lixilinx@gmail.com; last updated in Sept., 2025.
|
|
38
|
+
Main refs: https://arxiv.org/abs/1512.04202; https://arxiv.org/abs/2402.11858.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
from typing import TYPE_CHECKING, cast
|
|
42
|
+
|
|
43
|
+
import torch
|
|
44
|
+
|
|
45
|
+
from ....utils.python_tools import LazyLoader
|
|
46
|
+
|
|
47
|
+
opt_einsum = LazyLoader("opt_einsum")
|
|
48
|
+
|
|
49
|
+
if TYPE_CHECKING:
|
|
50
|
+
import opt_einsum as _opt_einsum
|
|
51
|
+
opt_einsum = cast(_opt_einsum, opt_einsum)
|
|
52
|
+
|
|
53
|
+
def norm_lower_bound_spd(A, k=32, half_iters=2):
    """
    Returns a cheap lower bound for the spectral norm of a symmetric positive definite matrix A.

    Args:
        A: square matrix, expected to be symmetric/Hermitian positive definite.
        k: the dim of subspace, suggesting 128 for bfloat16, 32 for float32 and 4 for float64
            (tested on the original author's laptop 4070 GPU).
        half_iters: half of the number of subspace iterations, suggesting 2.

    Returns:
        A 0-dim tensor: the largest row norm of the iterated subspace, rescaled by the normalizing factor.

    A rough norm estimation is good, and we don't orthonormalize the subspace vectors.

    The initial noise space V is rotated such that its centroid aligns with the largest row of A.
    Hence, each row of V and the largest row of A has an angle about acos(1/sqrt(k)) when k << dim(A).
    This feature makes the subspace iteration more robust for large matrices with very low rank.
    A simplified branchless approximate implementation is provided here.
    """
    smallest_normal = torch.finfo(A.dtype).smallest_normal
    # For this function the diagonal's largest real entry is used as the scale of A.
    normalizing_factor = A.diagonal().real.amax() + smallest_normal
    A = A / normalizing_factor  # (complex tensor) / (subnormal number) could produce inf or nan unexpectedly
    j = torch.argmax(torch.linalg.vector_norm(A, dim=1))  # index of the largest row of A
    V = torch.randn(k, A.shape[1], dtype=A.dtype, device=A.device)
    # Flip each noise row so that it has a non-negative correlation with A[j], then shift by A[j].
    V = A[j] + torch.sgn(torch.sum(A[j] * V.conj(), dim=1, keepdim=True)) * V  # torch.sign for real
    for _ in range(half_iters):
        V = V @ A
        # smallest_normal keeps the division finite when a row collapses to zero
        V /= torch.linalg.vector_norm(V, dim=1, keepdim=True) + smallest_normal
        V = V @ A
    return normalizing_factor * torch.amax(torch.linalg.vector_norm(V, dim=1))
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def norm_lower_bound_skh(A, k=32, half_iters=2):
    """
    Returns a cheap lower bound for the spectral norm of a skew-Hermitian matrix A.

    Args:
        A: square matrix, expected to be skew-symmetric/skew-Hermitian.
        k: the dim of subspace, suggesting 128 for bfloat16, 32 for float32 and 4 for float64
            (tested on the original author's laptop 4070 GPU).
        half_iters: half of the number of subspace iterations, suggesting 2.

    Returns:
        A 0-dim tensor: the largest row norm of the iterated subspace, rescaled by the normalizing factor.

    A rough norm estimation is good, and we don't orthonormalize the subspace vectors.

    The initial noise space V is rotated such that its centroid aligns with the largest row of A.
    Hence, each row of V and the largest row of A has an angle about acos(1/sqrt(k)) when k << dim(A).
    This feature makes the subspace iteration more robust for large matrices with very low rank.
    A simplified branchless approximate implementation is provided here.
    """
    smallest_normal = torch.finfo(A.dtype).smallest_normal
    # The largest absolute entry is used as the scale of A here.
    normalizing_factor = A.abs().amax() + smallest_normal
    A = A / normalizing_factor  # (complex tensor) / (subnormal number) could produce inf or nan unexpectedly
    j = torch.argmax(torch.linalg.vector_norm(A, dim=1))  # index of the largest row of A
    V = torch.randn(k, A.shape[1], dtype=A.dtype, device=A.device)
    # Flip each noise row so that it has a non-negative correlation with A[j], then shift by A[j].
    V = A[j] + torch.sgn(torch.sum(A[j] * V.conj(), dim=1, keepdim=True)) * V  # torch.sign for real
    for _ in range(half_iters):
        V = V @ A
        # smallest_normal keeps the division finite when a row collapses to zero
        V /= torch.linalg.vector_norm(V, dim=1, keepdim=True) + smallest_normal
        V = V @ A
    return normalizing_factor * torch.amax(torch.linalg.vector_norm(V, dim=1))
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def lift2single(x):
    """
    Lift half or lower precision to single precision; leave single or higher precision unchanged.

    The decision is made from the machine epsilon of x.dtype: dtypes with eps > 1e-6 are cast to
    torch.float32, anything else is returned as the very same tensor object (no copy).
    """
    return x.to(torch.float32) if torch.finfo(x.dtype).eps > 1e-6 else x
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def procrustes_step(Q, max_step_size=1/8):
    """
    An in-place (update Q directly) online solver for the orthogonal Procrustes problem,
        min_U || U Q - I ||_F,   s.t. U^H U = I
    by rotating Q as exp(a R) Q, where R = Q^H - Q is the generator and ||a R|| < 1.

    Args:
        Q: square matrix, modified in place; nothing is returned.
        max_step_size: upper bound for the step size a.
            Do not set max_step_size > 1/4 as we only expand exp(a R) to its 2nd term.

    Note that U(n) is connected and such rotations can make most complex Q SPD except for convergence to saddle points.
    However, O(n) is not connected. Hence, such SO(n) rotations can only make real Q with det(Q) > 0 SPD.

    We have simplified the original implementation. The one branch here is necessary for line search.
    """
    R = Q.H - Q
    R /= norm_lower_bound_skh(R) + torch.finfo(R.dtype).smallest_normal  # normalize R as typically it's too small
    RQ = R @ Q
    RRQ = R @ RQ
    tr_RQ = RQ.diagonal().real.sum()  # tr_RQ >=0 by theory; torch.trace not implemented for CPU bfloat16, so sum(diag())
    tr_RRQ = RRQ.diagonal().real.sum()  # line search is needed if tr_RRQ < 0
    # With tr_RRQ < 0 the step is -tr_RQ / tr_RRQ clamped to [0, max_step_size];
    # otherwise the full max_step_size is taken.
    a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size), max_step_size)
    # Second-order expansion of exp(a R) Q - Q = a R Q + 0.5 a^2 R R Q.
    Q.add_(a * (RQ + 0.5 * a * RRQ))
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
############# Begin of PSGD Kronecker product preconditioners #############
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def init_kron(t, Scale=1.0, max_size=float("inf"), max_skew=1.0, dQ="QEQ"):
    """
    For a scalar or tensor t, we initialize its states (preconditioner Q and Lipschitz smoothness constant L),
    and reusable contraction expressions for updating Q and preconditioning gradient.

    1, The preconditioner Q is initialized to
        Q = Scale * I = Scale * kron(eye(t.shape[0]), eye(t.shape[1]), ...)
       where the eye(.) may be replaced with diag(ones(.)) if that dim is too large, determined by max_size and max_skew.
       A dim of size `size` gets a diagonal factor when
        size <= 1 or size > max_size or size**2 > max_skew * t.numel().
       For dQ == "QUAD4P" the factors hold P directly, so Scale is squared first.

       The Lipschitz smoothness constant L for Q is initialized to zero.

    2, A series of einsum contract expressions. The following subscript examples are for a 5th order tensor.
        2.1, exprP is the expression for applying the Preconditioner on the gradient, e.g.,
                'aA,bB,cC,dD,eE,aα,bβ,cγ,dδ,eε,αβγδε->ABCDE'
        2.2, the i-th expression of exprGs is for the contraction of two tensors that only keeps the i-th dim, e.g.,
                'abCde,abγde->Cγ'
            for i=2. It's useful for Gradient calculation.
        2.3, exprA is the expression for applying All the factors of Q on a tensor, e.g.,
                'aA,bB,cC,dD,eE,ABCDE->abcde'
        2.4, the i-th expression of exprQs is the expression for applying the i-th factor of Q on a tensor, e.g.,
                'Cγ,abγde->abCde'
            for i=2.

    Returns:
        [[Q, L], exprs] where the content of the tuple exprs depends on dQ:
            "QEP":                                   (exprP, exprGs, exprQs)
            "EQ":                                    (exprP, exprGs, exprA)
            "QEQ", "QUAD", "Q0p5EQ1p5", "Q0.5EQ1.5": (exprP, exprGs)
            "QUAD4P":                                (exprA, exprGs)

    Raises:
        ValueError: if t has more than 26 dims.
        AssertionError: if dQ is not one of the choices listed above.

    Please check https://drive.google.com/file/d/1CEEq7A3_l8EcPEDa_sYtqr5aMLVeZWL7/view?usp=drive_link for notations and derivations.
    """
    if dQ == "QUAD4P":  # the only case that we fit P directly; so square Scale
        Scale = Scale ** 2
    shape = t.shape
    if len(shape) == 0:  # scalar
        Q = [Scale * torch.ones_like(t),]
        L = [lift2single(torch.zeros_like(t.real)),]
        exprA = opt_einsum.contract_expression(",->", Q[0].shape, t.shape)
        exprP = opt_einsum.contract_expression(",,->", Q[0].shape, Q[0].shape, t.shape)
        exprGs = [opt_einsum.contract_expression(",->", t.shape, t.shape),]
        exprQs = [opt_einsum.contract_expression(",->", Q[0].shape, t.shape),]
    else:  # tensor
        order = len(shape)
        if order > 26:
            raise ValueError(f"Got tensor with dim {len(t.shape)}; einsum runs out of letters; replace 26 with larger numbers.")

        scale = Scale ** (1/order)  # each factor carries the order-th root so that the product is Scale

        sym = opt_einsum.get_symbol
        Q, L = [], []
        exprGs, exprQs = [], []
        piece1A, piece2A, piece3A = [], "", ""  # used for getting the subscripts for exprA
        piece1P, piece2P, piece3P, piece4P = [], [], "", ""  # used for getting the subscripts for exprP
        for i, size in enumerate(shape):
            L.append(lift2single(torch.zeros([], dtype=t.real.dtype, device=t.device)))

            # Three disjoint symbol ranges are used for dim i: i, i + 26 and (matrix factors only) i + 805.
            a, b = sym(i), sym(i + 26)
            # subscripts of t with the i-th symbol swapped for b
            piece1 = "".join([b if j == i else sym(j) for j in range(order)])

            if size <= 1 or size > max_size or size**2 > max_skew * t.numel():
                # use diagonal matrix as preconditioner for this dim
                Q.append(scale * torch.ones(size, dtype=t.dtype, device=t.device))

                piece1A.append(a)
                piece2A = piece2A + a
                piece3A = piece3A + a

                piece1P.append(b)
                piece2P.append(b)
                piece3P = piece3P + b
                piece4P = piece4P + b

                subscripts = piece1 + "," + piece1 + "->" + b
                exprGs.append(opt_einsum.contract_expression(subscripts, t.shape, t.shape))

                subscripts = b + "," + piece1 + "->" + piece1
                exprQs.append(opt_einsum.contract_expression(subscripts, Q[-1].shape, t.shape))
            else:  # use matrix preconditioner for this dim
                Q.append(scale * torch.eye(size, dtype=t.dtype, device=t.device))

                piece1A.append(a + b)
                piece2A = piece2A + b
                piece3A = piece3A + a

                c = sym(i + 805)
                piece1P.append(a + b)
                piece2P.append(a + c)
                piece3P = piece3P + c
                piece4P = piece4P + b

                # subscripts of t with the i-th symbol swapped for c
                piece2 = "".join([c if j == i else sym(j) for j in range(order)])
                subscripts = piece1 + "," + piece2 + "->" + b + c
                exprGs.append(opt_einsum.contract_expression(subscripts, t.shape, t.shape))

                subscripts = b + c + "," + piece2 + "->" + piece1
                exprQs.append(opt_einsum.contract_expression(subscripts, Q[-1].shape, t.shape))

        subscripts = ",".join(piece1A) + "," + piece2A + "->" + piece3A
        exprA = opt_einsum.contract_expression(subscripts, *[q.shape for q in Q], t.shape)

        subscripts = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
        exprP = opt_einsum.contract_expression(subscripts, *[q.shape for q in Q], *[q.shape for q in Q], t.shape)

    exprGs, exprQs = tuple(exprGs), tuple(exprQs)
    if dQ == "QEP":
        return [[Q, L], (exprP, exprGs, exprQs)]
    elif dQ == "EQ":
        return [[Q, L], (exprP, exprGs, exprA)]
    elif (dQ == "QEQ") or (dQ == "QUAD") or (dQ == "Q0p5EQ1p5") or (dQ == "Q0.5EQ1.5"):
        return [[Q, L], (exprP, exprGs)]
    else:  # the only case that we fit P directly
        assert dQ == "QUAD4P", "Invalid choice for dQ"
        return [[Q, L], (exprA, exprGs)]
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def balance_kron_precond(Q):
    """
    In place balancing the dynamic ranges of the factors of Q to avoid over/under-flow.

    Every factor is rescaled so that its largest absolute entry equals the geometric mean of the
    largest absolute entries of all factors. Nothing is done when Q has fewer than two factors.

    Args:
        Q: list of factor tensors; each one is modified in place. Nothing is returned.
    """
    order = len(Q)  # order of tensor or the number of factors in Q
    if order > 1:
        norms = [torch.max(torch.abs(q)) for q in Q]
        gmean = torch.prod(torch.stack(norms))**(1/order)  # geometric mean
        for q, norm in zip(Q, norms):
            q.mul_(gmean/norm)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def update_precond_kron_eq(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, balance_prob=0.01):
    """
    The raw function for updating the Kron preconditioner Q and Lipschitz smoothness constant L with pair (V, Hvp),
    where Q is updated as dQ = E*Q,
    the pair (V, Hvp) can be (vector, hess-vector-prod) or (randn, gradient/momentum).
    The damping logic is not included here.

    Args:
        QL: the pair [Q, L]; both lists of tensors are updated in place.
        exprs: 3-tuple whose last two entries are (exprGs, exprA); the first entry is not used here.
        V, Hvp: the pair described above.
        lr: step size for the update of Q.
        betaL: EMA factor for the Lipschitz smoothness constant estimation.
        balance_prob: probability of balancing the factors of Q after this update.
    """
    Q, L = QL
    _, exprGs, exprA = exprs

    def solve_triangular_right(B, A):
        # return B @ inv(A); solved in at least single precision, then cast back to B's dtype
        if B.dim() > 1:
            return torch.linalg.solve_triangular(lift2single(A), lift2single(B), upper=True, left=False).to(B.dtype)
        else:  # torch.linalg.solve_triangular complains if B.dim() < 2. So insert None.
            return (torch.linalg.solve_triangular(lift2single(A), lift2single(B[None,:]), upper=True, left=False)[0]).to(B.dtype)

    A = exprA(*Q, Hvp)

    order = V.dim()
    p = list(range(order))
    conjB = torch.permute(V.conj(), p[1:] + p[:1])  # permute dims like [0,1,2,3,4] -> [1,2,3,4,0]
    for i, q in enumerate(Q):
        # the dim being solved always sits last, so the right-solve acts on it
        conjB = conjB/q if q.dim() < 2 else solve_triangular_right(conjB, q)
        if i < order - 1:  # transpose dims like [1,2,3,4,0]->[0,2,3,4,1]->[0,1,3,4,2]->[0,1,2,4,3]->[0,1,2,3,4]
            conjB = torch.transpose(conjB, i, order - 1)

    for i, q in enumerate(Q):
        term1 = exprGs[i](A, A.conj())
        term2 = exprGs[i](conjB.conj(), conjB)

        if q.dim() < 2:  # q is a diagonal matrix or scalar preconditioner
            ell = torch.max(torch.real(term1 + term2))
            # EMA of ell that is never allowed to fall below the current ell
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            q.sub_(lr/L[i] * (term1 - term2) * q)  # q.mul_(1 - lr/L[i] * (term1 - term2)): larger roundoff errors
        else:  # q is a matrix preconditioner
            ell = norm_lower_bound_spd(term1 + term2)
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            # only the upper triangular part of the gradient is applied
            q.sub_(lr/L[i] * torch.triu(term1 - term2) @ q)

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(Q)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def precond_grad_kron(QL, exprs, G):
    """
    Precondition gradient G with Kron preconditioner Q.

    Args:
        QL: the pair [Q, L]; only Q (a list of factors) is read.
        exprs: tuple whose first entry is exprP, the contraction applying the preconditioner.
        G: the gradient to precondition.

    Returns:
        The result of exprP called with the conjugated factors, the factors, and G, in that order.
    """
    Q, exprP = QL[0], exprs[0]
    return exprP(*[q.conj() for q in Q], *Q, G)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def update_precond_kron_whiten_eq(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Update the Kron preconditioner Q as dQ = E*Q.

    A fresh standard normal tensor V with the shape of G is drawn, and the pair
    (V, G + damping*V) is handed to update_precond_kron_eq; lr, betaL and balance_prob are
    passed through unchanged. QL is updated in place and nothing is returned.
    """
    V = torch.randn_like(G)
    update_precond_kron_eq(QL, exprs, V, G + damping*V, lr=lr, betaL=betaL, balance_prob=balance_prob)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def update_precond_kron_whiten_qep(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=...):
    """
    Update the Kron preconditioner Q as dQ = Q*E*P.

    Args:
        QL: the pair [Q, L]; both lists of tensors are updated in place.
        exprs: the 3-tuple (exprP, exprGs, exprQs).
        G: the gradient (or momentum) to whiten.
        lr: step size for the update of Q.
        betaL: EMA factor for the Lipschitz smoothness constant estimation.
        damping: scale of the standard normal noise added to G.
        balance_prob: accepted for signature parity with the other update functions but never read;
            the factors of Q are balanced on every call.
    """
    Q, L = QL
    exprP, exprGs, exprQs = exprs

    # balancing is not optional as L for each factor is not scaling invariant
    balance_kron_precond(Q)

    total_numel = G.numel()
    Pg = exprP(*[q.conj() for q in Q], *Q, G + damping*torch.randn_like(G))
    for i, q in enumerate(Q):
        QPg = exprQs[i](q, Pg)
        term1 = exprGs[i](QPg, QPg.conj())
        if q.dim() < 2:  # diagonal or scalar Q
            term2 = total_numel/q.numel() * q * q.conj()
            ell = torch.max(torch.real(term1 + term2))
            # EMA of ell that is never allowed to fall below the current ell
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            q.mul_(1 - lr/L[i] * (term1 - term2))
        else:  # matrix Q
            term2 = total_numel/q.shape[0] * q @ q.H
            ell = norm_lower_bound_spd(term1 + term2)
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            q.sub_(lr/L[i] * (term1 - term2) @ q)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def update_precond_kron_whiten_qeq(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Update the Kron preconditioner Q as dQ = Q*E*Q.

    Args:
        QL: the pair [Q, L]; both lists of tensors are updated in place.
        exprs: the 2-tuple (exprP, exprGs).
        G: the gradient (or momentum) to whiten.
        lr: step size for the update of Q.
        betaL: EMA factor for the Lipschitz smoothness constant estimation.
        damping: scale of the standard normal noise added to G.
        balance_prob: probability of balancing the factors of Q after this update.
    """
    Q, L = QL
    exprP, exprGs = exprs

    total_numel = G.numel()
    Pg = exprP(*[q.conj() for q in Q], *Q, G + damping*torch.randn_like(G))
    for i, q in enumerate(Q):
        term1 = exprGs[i](Pg, Pg.conj())
        if q.dim() < 2:  # diagonal or scalar Q
            term2 = total_numel/q.numel()  # times I
            ell = torch.max(torch.real(term1)) + term2
            # EMA of ell that is never allowed to fall below the current ell
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            q.mul_(1 - lr/L[i] * (term1 - term2))
        else:  # matrix Q
            term2 = total_numel/q.shape[0]  # times I
            ell = norm_lower_bound_spd(term1) + term2
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            # Q multiplies the gradient from the left here
            q.sub_(lr/L[i] * (q @ term1 - q * term2))

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(Q)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def update_precond_kron_whiten_q0p5eq1p5(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Update the Kron preconditioner Q as dQ = Q^0.5 * E * Q^1.5.

    Args:
        QL: the pair [Q, L]; both lists of tensors are updated in place.
        exprs: the 2-tuple (exprP, exprGs).
        G: the gradient (or momentum) to whiten.
        lr: step size for the update of Q.
        betaL: EMA factor for the Lipschitz smoothness constant estimation.
        damping: scale of the standard normal noise added to G.
        balance_prob: probability of balancing the factors of Q after this update.
    """
    Q, L = QL
    exprP, exprGs = exprs

    total_numel = G.numel()
    Pg = exprP(*[q.conj() for q in Q], *Q, G + damping*torch.randn_like(G))
    for i, q in enumerate(Q):
        term1 = exprGs[i](Pg, Pg.conj())
        if q.dim() < 2:  # diagonal or scalar Q
            term2 = total_numel/q.numel()  # times I
            ell = torch.max(torch.real(term1)) + term2
            # EMA of ell that is never allowed to fall below the current ell
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            q.mul_(1 - lr/L[i] * (term1 - term2))
        else:  # matrix Q
            term2 = total_numel/q.shape[0]  # times I
            ell = norm_lower_bound_spd(term1) + term2
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            q.sub_(lr/L[i] * (term1 @ q - term2 * q))
            # one online Procrustes rotation after each matrix update, applied in place
            procrustes_step(q)

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(Q)
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def update_precond_kron_whiten_quad(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Update the Kron preconditioner Q with a quadratic form.

    Args:
        QL: the pair [Q, L]; both lists of tensors are updated in place.
        exprs: the 2-tuple (exprP, exprGs).
        G: the gradient (or momentum) to whiten.
        lr: step size for the update of Q; each of the two half steps uses lr/2.
        betaL: EMA factor for the Lipschitz smoothness constant estimation.
        damping: scale of the standard normal noise added to G.
        balance_prob: probability of balancing the factors of Q after this update.
    """
    Q, L = QL
    exprP, exprGs = exprs

    total_numel = G.numel()
    Pg = exprP(*[q.conj() for q in Q], *Q, G + damping*torch.randn_like(G))
    for i, q in enumerate(Q):
        term1 = exprGs[i](Pg, Pg.conj())
        if q.dim() < 2:  # diagonal or scalar Q
            term2 = total_numel/q.numel()  # times I
            ell = torch.max(torch.real(term1)) + term2
            # EMA of ell that is never allowed to fall below the current ell
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            gain = 1 - lr/2/L[i] * (term1 - term2)
            q.mul_(gain * gain)
        else:  # matrix Q
            term2 = total_numel/q.shape[0]  # times I
            ell = norm_lower_bound_spd(term1) + term2
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            # half step from the left, then half step from the right
            p = q - lr/2/L[i] * (term1 @ q - term2 * q)
            p = p - lr/2/L[i] * (p @ term1 - p * term2)
            q.copy_((p + p.H)/2)  # p must be symmetric/hermitian

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(Q)
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def update_precond_kron_whiten_quad4p(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Almost the same as function update_precond_kron_whiten_quad except that fitting P directly.
    This is the only case that we fit P directly (Q here is P). Vulnerable to numerical errors.

    Args:
        QL: the pair [Q, L] (Q holds the factors of P); both lists of tensors are updated in place.
        exprs: the 2-tuple (exprA, exprGs).
        G: the gradient (or momentum) to whiten.
        lr: step size for the update; each of the two steps uses the full lr.
        betaL: EMA factor for the Lipschitz smoothness constant estimation.
        damping: scale of the standard normal noise added to G.
        balance_prob: probability of balancing the factors of Q after this update.
    """
    Q, L = QL
    exprA, exprGs = exprs

    total_numel = G.numel()
    Pg = exprA(*Q, G + damping*torch.randn_like(G))  # Q actually is P; so just applying all its factors once.
    for i, q in enumerate(Q):
        term1 = exprGs[i](Pg, Pg.conj())
        if q.dim() < 2:  # diagonal or scalar Q
            term2 = total_numel/q.numel()  # times I
            ell = torch.max(torch.real(term1)) + term2
            # EMA of ell that is never allowed to fall below the current ell
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            gain = 1 - lr/L[i] * (term1 - term2)
            q.mul_(gain * gain)
        else:  # matrix Q
            term2 = total_numel/q.shape[0]  # times I
            ell = norm_lower_bound_spd(term1) + term2
            L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
            # step from the left, then step from the right
            p = q - lr/L[i] * (term1 @ q - term2 * q)
            p = p - lr/L[i] * (p @ term1 - p * term2)
            q.copy_((p + p.H)/2)  # p must be symmetric/hermitian

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(Q)
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
class KronWhiten:
|
|
451
|
+
"""
|
|
452
|
+
Implements the PSGD optimizer with the Kronecker product gradient/momentum whitening preconditioner.
|
|
453
|
+
Most of the time, the hyperparameter name says it all. Here are some comments on a few key hyperparameters.
|
|
454
|
+
|
|
455
|
+
1, preconditioner_max_size and preconditioner_max_skew. These two together control the complexity of the preconditioners.
|
|
456
|
+
For example, we are to precondition a 2D gradient with shape 10 x 50.
|
|
457
|
+
With preconditioner_max_size 20, we use a dense preconditioner for the first dim since 10 <= 20 and diagonal preconditioner for the second dim since 50 > 20.
|
|
458
|
+
With preconditioner_max_skew 1.5, we use a dense preconditioner for the first dim since 10/50 <= 1.5 and diagonal preconditioner for the second dim since 50/10 > 1.5.
|
|
459
|
+
|
|
460
|
+
2, grad_clip_max_amp, betaL and damping. These three together help to stabilize the training.
|
|
461
|
+
PSGD here tries to normalize the gradients to unit amplitude. This can be problematic when gradients approach zeros.
|
|
462
|
+
The most effective way is to clip the preconditioned gradients if their amplitudes exceed grad_clip_max_amp, say 1.0.
|
|
463
|
+
Another way is to damp and upper bound the fitted preconditioner such that P < eye/damping.
|
|
464
|
+
For extremely sparse gradients, increasing betaL (say to 0.999) helps a lot, where betaL is the EMA factor for the L-smoothness constant (wrt Q) estimation.
|
|
465
|
+
|
|
466
|
+
3, Lastly, dQ is for the selection of geometry for preconditioner update. QEQ, QUAD and Q0p5EQ1p5 all are good choices.
|
|
467
|
+
Q is initialized to preconditioner_init_scale * eye. Boolean setting whiten_grad decides to whiten whether the gradient or momentum.
|
|
468
|
+
Always good to check https://arxiv.org/abs/2402.11858 for math details.
|
|
469
|
+
"""
|
|
470
|
+
def __init__(self, params_with_grad,
|
|
471
|
+
preconditioner_max_size=float("inf"), preconditioner_max_skew=1.0, preconditioner_init_scale:float|None=None,
|
|
472
|
+
lr_params=0.001, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
|
|
473
|
+
grad_clip_max_amp=float("inf"), preconditioner_update_probability=1.0, whiten_grad=True, dQ="Q0.5EQ1.5"):
|
|
474
|
+
# mutable members
|
|
475
|
+
self.lr_params = lr_params
|
|
476
|
+
self.lr_preconditioner = lr_preconditioner
|
|
477
|
+
self.betaL = betaL # beta for the Lipschitz smoothness constant estimation; set to a large value for sparse gradients
|
|
478
|
+
self.damping = damping # to damp and upper bound the preconditioner such that P < eye/damping
|
|
479
|
+
self.momentum = momentum if (0<momentum<1) else 0.0
|
|
480
|
+
self.grad_clip_max_amp = grad_clip_max_amp # clip grad once its average amplitude exceeds this max amplitude setting
|
|
481
|
+
self.preconditioner_update_probability = preconditioner_update_probability
|
|
482
|
+
# protected members
|
|
483
|
+
self._preconditioner_max_size = preconditioner_max_size
|
|
484
|
+
self._preconditioner_max_skew = preconditioner_max_skew
|
|
485
|
+
params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
|
|
486
|
+
self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
|
|
487
|
+
self._num_params = sum([p.numel() for p in self._params_with_grad])
|
|
488
|
+
if preconditioner_init_scale is None:
|
|
489
|
+
self._QLs_exprs = None # initialize on the fly
|
|
490
|
+
print("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
|
|
491
|
+
else:
|
|
492
|
+
self._QLs_exprs = [init_kron(p.squeeze(), preconditioner_init_scale, preconditioner_max_size, preconditioner_max_skew, dQ) for p in self._params_with_grad]
|
|
493
|
+
self._ms, self._counter_m = None, 0 # momentum buffers and counter
|
|
494
|
+
self._whiten_grad = whiten_grad # set to False to whiten momentum.
|
|
495
|
+
if not whiten_grad:
|
|
496
|
+
assert self.momentum > 0, "Cannot whiten momentum if the momentum setting is zero."
|
|
497
|
+
print(f"Recommend to reduce lr_params by {int(((1 + momentum)/(1 - momentum))**0.5)} times")
|
|
498
|
+
self._dQ = dQ
|
|
499
|
+
if dQ == "QUAD4P": # the only case that we fit P directly
|
|
500
|
+
assert max([torch.finfo(p.dtype).eps for p in self._params_with_grad]) < 1e-6, "Directly fitting P needs at least single precision"
|
|
501
|
+
self._update_precond = update_precond_kron_whiten_quad4p
|
|
502
|
+
self._precond_grad = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)
|
|
503
|
+
else:
|
|
504
|
+
self._precond_grad = precond_grad_kron
|
|
505
|
+
if dQ == "QEP":
|
|
506
|
+
self._update_precond = update_precond_kron_whiten_qep
|
|
507
|
+
elif dQ == "EQ":
|
|
508
|
+
self._update_precond = update_precond_kron_whiten_eq
|
|
509
|
+
elif dQ == "QEQ":
|
|
510
|
+
self._update_precond = update_precond_kron_whiten_qeq
|
|
511
|
+
elif dQ == "QUAD":
|
|
512
|
+
self._update_precond = update_precond_kron_whiten_quad
|
|
513
|
+
else:
|
|
514
|
+
assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), "Invalid choice for dQ"
|
|
515
|
+
self._update_precond = update_precond_kron_whiten_q0p5eq1p5
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
@torch.no_grad()
|
|
519
|
+
def step(self, closure):
|
|
520
|
+
"""
|
|
521
|
+
Performs one step of PSGD with the Kronecker product gradient/momentum whitening preconditioner.
|
|
522
|
+
"""
|
|
523
|
+
with torch.enable_grad():
|
|
524
|
+
closure_returns = closure()
|
|
525
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
526
|
+
grads = [g.squeeze() for g in torch.autograd.grad(loss, self._params_with_grad)]
|
|
527
|
+
|
|
528
|
+
if self._QLs_exprs is None:
|
|
529
|
+
scale = max([torch.mean((torch.abs(g))**4) for g in grads])
|
|
530
|
+
scale = (scale + self.damping**4)**(-1/8)
|
|
531
|
+
self._QLs_exprs = [init_kron(g, scale, self._preconditioner_max_size, self._preconditioner_max_skew, self._dQ) for g in grads]
|
|
532
|
+
|
|
533
|
+
if self.momentum > 0:
|
|
534
|
+
beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
|
|
535
|
+
self._counter_m += 1
|
|
536
|
+
if self._ms is None:
|
|
537
|
+
self._ms = [torch.zeros_like(g) for g in grads]
|
|
538
|
+
|
|
539
|
+
[m.mul_(beta).add_(g, alpha=1 - beta) for (m, g) in zip(self._ms, grads)]
|
|
540
|
+
else:
|
|
541
|
+
self._ms, self._counter_m = None, 0
|
|
542
|
+
|
|
543
|
+
if torch.rand([]) < self.preconditioner_update_probability: # update Q
|
|
544
|
+
if self._whiten_grad: # Q whitens gradient
|
|
545
|
+
[self._update_precond(*QL_exprs, g, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
|
|
546
|
+
for (QL_exprs, g) in zip(self._QLs_exprs, grads)]
|
|
547
|
+
else: # Q whitens momentum
|
|
548
|
+
[self._update_precond(*QL_exprs, m, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
|
|
549
|
+
for (QL_exprs, m) in zip(self._QLs_exprs, self._ms)]
|
|
550
|
+
|
|
551
|
+
if self.momentum > 0: # precondition momentum
|
|
552
|
+
pre_grads = [self._precond_grad(*QL_exprs, m) for (QL_exprs, m) in zip(self._QLs_exprs, self._ms)]
|
|
553
|
+
else: # precondition gradient
|
|
554
|
+
pre_grads = [self._precond_grad(*QL_exprs, g) for (QL_exprs, g) in zip(self._QLs_exprs, grads)]
|
|
555
|
+
|
|
556
|
+
lr = self.lr_params
|
|
557
|
+
if self.grad_clip_max_amp < float("inf"): # clip preconditioned gradient
|
|
558
|
+
avg_amp = torch.sqrt(torch.real(sum([torch.sum(g*g.conj()) for g in pre_grads]))/self._num_params)
|
|
559
|
+
if avg_amp > self.grad_clip_max_amp:
|
|
560
|
+
lr = lr * self.grad_clip_max_amp / avg_amp
|
|
561
|
+
|
|
562
|
+
# Update the parameters.
|
|
563
|
+
[param.subtract_(lr*g.view_as(param)) for (param, g) in zip(self._params_with_grad, pre_grads)]
|
|
564
|
+
|
|
565
|
+
# return whatever closure returns
|
|
566
|
+
return closure_returns
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def update_precond_kron_newton_eq(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Update the Kron Newton-type preconditioner Q as dQ = E*Q from a (vector, Hessian-vector product) pair (V, Hvp).

    Gaussian noise scaled by `damping` is added to Hvp, then the work is delegated to update_precond_kron_eq.
    """
    damped_hvp = Hvp + damping*torch.randn_like(Hvp)
    update_precond_kron_eq(QL, exprs, V, damped_hvp, lr=lr, betaL=betaL, balance_prob=balance_prob)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def update_precond_kron_newton_qep(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=...):
    """
    Update the Kron Newton-type preconditioner Q as dQ = Q*E*P from a (vector, Hessian-vector product) pair (V, Hvp).

    Q factors and their Lipschitz estimates L are modified in place.
    `balance_prob` is accepted only for signature parity with the sibling updaters; it is not read here
    because the factors are balanced on every call.
    """
    factors, lipschitz = QL
    expr_precond, expr_grads, expr_factors = exprs

    # balancing is not optional as L for each factor is not scaling invariant
    balance_kron_precond(factors)
    damped_hvp = Hvp + damping*torch.randn_like(Hvp)
    Ph = expr_precond(*[f.conj() for f in factors], *factors, damped_hvp)

    for idx, factor in enumerate(factors):
        apply_factor, contract = expr_factors[idx], expr_grads[idx]
        QPh = apply_factor(factor, Ph)
        Qv = apply_factor(factor, V)
        hvp_stat = contract(QPh, QPh.conj())
        vec_stat = contract(Qv, Qv.conj())

        is_matrix = factor.dim() >= 2  # otherwise a diagonal or scalar factor
        stat_sum = hvp_stat + vec_stat
        ell = norm_lower_bound_spd(stat_sum) if is_matrix else torch.max(torch.real(stat_sum))

        ema = lipschitz[idx]
        ema.copy_(torch.max(betaL*ema + (1 - betaL)*ell, ell))

        scaled_residual = lr/ema * (hvp_stat - vec_stat)
        if is_matrix:
            factor.sub_(scaled_residual @ factor)
        else:
            factor.mul_(1 - scaled_residual)
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def update_precond_kron_newton_qeq(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Update the Kron Newton-type preconditioner Q as dQ = Q*E*Q from a (vector, Hessian-vector product) pair (V, Hvp).

    Q factors and their Lipschitz estimates L are modified in place; with probability `balance_prob`
    the factors are re-balanced afterwards.
    """
    factors, lipschitz = QL
    expr_precond, expr_grads = exprs
    damped_hvp = Hvp + damping*torch.randn_like(Hvp)
    Ph = expr_precond(*[f.conj() for f in factors], *factors, damped_hvp)

    for idx, factor in enumerate(factors):
        contract = expr_grads[idx]
        hvp_stat = contract(Ph, Ph.conj())
        vec_stat = contract(V, V.conj())

        is_matrix = factor.dim() >= 2  # otherwise a diagonal or scalar factor
        stat_sum = hvp_stat + vec_stat
        ell = norm_lower_bound_spd(stat_sum) if is_matrix else torch.max(torch.real(stat_sum))

        ema = lipschitz[idx]
        ema.copy_(torch.max(betaL*ema + (1 - betaL)*ell, ell))

        residual = hvp_stat - vec_stat
        if is_matrix:
            factor.sub_(lr/ema * factor @ residual)
        else:
            factor.mul_(1 - lr/ema * residual)

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(factors)
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def update_precond_kron_newton_q0p5eq1p5(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Update the Kron Newton-type preconditioner Q as dQ = Q^0.5 * E * Q^1.5 from a (vector, Hessian-vector product) pair (V, Hvp).

    Q factors and their Lipschitz estimates L are modified in place; matrix factors additionally go
    through procrustes_step after their update. With probability `balance_prob` the factors are re-balanced.
    """
    factors, lipschitz = QL
    expr_precond, expr_grads = exprs
    damped_hvp = Hvp + damping*torch.randn_like(Hvp)
    Ph = expr_precond(*[f.conj() for f in factors], *factors, damped_hvp)

    for idx, factor in enumerate(factors):
        contract = expr_grads[idx]
        hvp_stat = contract(Ph, Ph.conj())
        vec_stat = contract(V, V.conj())

        is_matrix = factor.dim() >= 2  # otherwise a diagonal or scalar factor
        stat_sum = hvp_stat + vec_stat
        ell = norm_lower_bound_spd(stat_sum) if is_matrix else torch.max(torch.real(stat_sum))

        ema = lipschitz[idx]
        ema.copy_(torch.max(betaL*ema + (1 - betaL)*ell, ell))

        scaled_residual = lr/ema * (hvp_stat - vec_stat)
        if is_matrix:
            factor.sub_(scaled_residual @ factor)
            procrustes_step(factor)
        else:
            factor.mul_(1 - scaled_residual)

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(factors)
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def update_precond_kron_newton_quad(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Update the Kron Newton-type preconditioner Q with a quadratic form for dQ, from a (vector, Hessian-vector product) pair (V, Hvp).

    Q factors and their Lipschitz estimates L are modified in place; matrix factors are symmetrized
    after the two-sided update. With probability `balance_prob` the factors are re-balanced.
    """
    factors, lipschitz = QL
    expr_precond, expr_grads = exprs
    damped_hvp = Hvp + damping*torch.randn_like(Hvp)
    Ph = expr_precond(*[f.conj() for f in factors], *factors, damped_hvp)

    for idx, factor in enumerate(factors):
        contract = expr_grads[idx]
        hvp_stat = contract(Ph, Ph.conj())
        vec_stat = contract(V, V.conj())

        is_matrix = factor.dim() >= 2  # otherwise a diagonal or scalar factor
        stat_sum = hvp_stat + vec_stat
        ell = norm_lower_bound_spd(stat_sum) if is_matrix else torch.max(torch.real(stat_sum))

        ema = lipschitz[idx]
        ema.copy_(torch.max(betaL*ema + (1 - betaL)*ell, ell))

        # half of the step is applied from each side, hence lr/2
        half_err = lr/2/ema * (hvp_stat - vec_stat)
        if is_matrix:
            updated = factor - half_err @ factor   # left-sided half step
            updated = updated - updated @ half_err  # right-sided half step
            factor.copy_((updated + updated.H)/2)   # the factor must stay symmetric or hermitian
        else:
            gain = 1 - half_err
            factor.mul_(gain * gain)

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(factors)
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def update_precond_kron_newton_quad4p(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
    """
    Variant of update_precond_kron_newton_quad that fits P directly (the Q stored here is P).

    This is the only updater that fits P directly, which makes it vulnerable to numerical errors.
    Factors and their Lipschitz estimates L are modified in place; with probability `balance_prob`
    the factors are re-balanced.
    """
    factors, lipschitz = QL
    expr_apply, expr_grads = exprs
    damped_hvp = Hvp + damping*torch.randn_like(Hvp)
    Ph = expr_apply(*factors, damped_hvp)  # the factors already represent P, so they are applied once

    for idx, factor in enumerate(factors):
        contract = expr_grads[idx]
        hvp_stat = contract(Ph, Ph.conj())
        vec_stat = contract(V, V.conj())

        is_matrix = factor.dim() >= 2  # otherwise a diagonal or scalar factor
        stat_sum = hvp_stat + vec_stat
        ell = norm_lower_bound_spd(stat_sum) if is_matrix else torch.max(torch.real(stat_sum))

        ema = lipschitz[idx]
        ema.copy_(torch.max(betaL*ema + (1 - betaL)*ell, ell))

        err = lr/ema * (hvp_stat - vec_stat)
        if is_matrix:
            updated = factor - err @ factor         # left-sided step
            updated = updated - updated @ err       # right-sided step
            factor.copy_((updated + updated.H)/2)   # the factor must stay symmetric or hermitian
        else:
            gain = 1 - err
            factor.mul_(gain * gain)

    if torch.rand([]) < balance_prob:  # balance factors of Q
        balance_kron_precond(factors)
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
class KronNewton:
|
|
709
|
+
"""
|
|
710
|
+
Implements the Kronecker product Newton-type preconditioner as a class.
|
|
711
|
+
Most of the time, the hyperparameter name says it all. Here are some comments on a few key parameters.
|
|
712
|
+
|
|
713
|
+
1, preconditioner_max_size and preconditioner_max_skew. These two together control the complexity of the preconditioners.
|
|
714
|
+
For example, we are to precondition a 2D gradient with shape 10 x 50.
|
|
715
|
+
With preconditioner_max_size 20, we use a dense preconditioner for the first dim since 10 <= 20 and diagonal preconditioner for the second dim since 50 > 20.
|
|
716
|
+
With preconditioner_max_skew 1.5, we use a dense preconditioner for the first dim since 10/50 <= 1.5 and diagonal preconditioner for the second dim since 50/10 > 1.5.
|
|
717
|
+
|
|
718
|
+
2, grad_clip_max_norm, betaL and damping. These three together help to stabilize the training.
|
|
719
|
+
The grad_clip_max_norm is used to clip the preconditioned gradient to stabilize the optimization as in the classic trust region method.
|
|
720
|
+
Setting damping is used to damp and upper bound the fitted preconditioner such that P < eye/damping.
|
|
721
|
+
For extremely sparse Hess-vector-prod, a large betaL (say 0.999) helps a lot, where betaL is the EMA factor for the L-smoothness constant (wrt Q) estimation.
|
|
722
|
+
|
|
723
|
+
3, exact_hessian_vector_product.
|
|
724
|
+
By setting this flag to False, the finite difference method will be used for Hvp approximation.
|
|
725
|
+
Be cautious with the finite difference method (possible numerical issues; the closure must behave like a pure function).
|
|
726
|
+
|
|
727
|
+
4, Lastly, dQ is for the selection of geometry for preconditioner update. QEQ, QUAD and Q0p5EQ1p5 all are good choices.
|
|
728
|
+
Both lr_params and lr_preconditioner are normalized learning rates.
|
|
729
|
+
Q is initialized to preconditioner_init_scale * eye.
|
|
730
|
+
Always good to check https://arxiv.org/abs/2402.11858 for math details.
|
|
731
|
+
"""
|
|
732
|
+
def __init__(self, params_with_grad, preconditioner_max_size=float("inf"), preconditioner_max_skew=1.0, preconditioner_init_scale:float|None=None,
|
|
733
|
+
lr_params=0.01, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
|
|
734
|
+
grad_clip_max_norm=float("inf"), preconditioner_update_probability=1.0,
|
|
735
|
+
exact_hessian_vector_product=True, dQ="Q0.5EQ1.5"):
|
|
736
|
+
# mutable members
|
|
737
|
+
self.lr_params = lr_params
|
|
738
|
+
self.lr_preconditioner = lr_preconditioner
|
|
739
|
+
self.betaL = betaL # beta for Lipschitz smoothness constant estimation; set to a large value for sparse Hvp
|
|
740
|
+
self.damping = damping # used to damp and upper bound P as P < eye/damping
|
|
741
|
+
self.momentum = momentum if (0<momentum<1) else 0.0
|
|
742
|
+
self.grad_clip_max_norm = grad_clip_max_norm
|
|
743
|
+
self.preconditioner_update_probability = preconditioner_update_probability
|
|
744
|
+
# protected members
|
|
745
|
+
self._preconditioner_max_size = preconditioner_max_size
|
|
746
|
+
self._preconditioner_max_skew = preconditioner_max_skew
|
|
747
|
+
params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
|
|
748
|
+
self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
|
|
749
|
+
eps = max([torch.finfo(p.dtype).eps for p in self._params_with_grad])
|
|
750
|
+
self._delta_param_scale = eps ** 0.5
|
|
751
|
+
if preconditioner_init_scale is None:
|
|
752
|
+
self._QLs_exprs = None # initialize on the fly
|
|
753
|
+
print("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
|
|
754
|
+
else:
|
|
755
|
+
self._QLs_exprs = [init_kron(p.squeeze(), preconditioner_init_scale, preconditioner_max_size, preconditioner_max_skew, dQ) for p in self._params_with_grad]
|
|
756
|
+
self._ms, self._counter_m = None, 0 # momentum buffers and counter
|
|
757
|
+
self._exact_hessian_vector_product = exact_hessian_vector_product
|
|
758
|
+
if not exact_hessian_vector_product:
|
|
759
|
+
print("FYI: Approximate Hvp with finite-difference method. Make sure that: 1) the closure behaves like a pure function; 2) delta param scale is proper.")
|
|
760
|
+
self._dQ = dQ
|
|
761
|
+
if dQ == "QUAD4P": # the only case that fits P directly
|
|
762
|
+
self._update_precond = update_precond_kron_newton_quad4p
|
|
763
|
+
self._precond_grad = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)
|
|
764
|
+
assert eps < 1e-6, "Directly fitting P needs at least single precision"
|
|
765
|
+
else:
|
|
766
|
+
self._precond_grad = precond_grad_kron
|
|
767
|
+
if dQ == "QUAD":
|
|
768
|
+
self._update_precond = update_precond_kron_newton_quad
|
|
769
|
+
elif dQ == "QEP":
|
|
770
|
+
self._update_precond = update_precond_kron_newton_qep
|
|
771
|
+
elif dQ == "EQ":
|
|
772
|
+
self._update_precond = update_precond_kron_newton_eq
|
|
773
|
+
elif dQ == "QEQ":
|
|
774
|
+
self._update_precond = update_precond_kron_newton_qeq
|
|
775
|
+
else:
|
|
776
|
+
assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), "Invalid choice for dQ"
|
|
777
|
+
self._update_precond = update_precond_kron_newton_q0p5eq1p5
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
@torch.no_grad()
|
|
781
|
+
def step(self, closure):
|
|
782
|
+
"""
|
|
783
|
+
Performs one step of PSGD with the Kronecker product Newton-type preconditioner.
|
|
784
|
+
"""
|
|
785
|
+
if (torch.rand([]) < self.preconditioner_update_probability) or (self._QLs_exprs is None):
|
|
786
|
+
# evaluates gradients, Hessian-vector product, and updates the preconditioner
|
|
787
|
+
if self._exact_hessian_vector_product:
|
|
788
|
+
with torch.enable_grad():
|
|
789
|
+
closure_returns = closure()
|
|
790
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
791
|
+
grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)
|
|
792
|
+
vs = [torch.randn_like(p) for p in self._params_with_grad]
|
|
793
|
+
Hvs = torch.autograd.grad(grads, self._params_with_grad, vs) # this line also works for complex matrices
|
|
794
|
+
else: # approximate the Hessian-vector product via finite-difference formulae. Use it with cautions.
|
|
795
|
+
with torch.enable_grad():
|
|
796
|
+
closure_returns = closure()
|
|
797
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
798
|
+
grads = torch.autograd.grad(loss, self._params_with_grad)
|
|
799
|
+
|
|
800
|
+
vs = [torch.randn_like(p) for p in self._params_with_grad]
|
|
801
|
+
[p.add_(v, alpha=self._delta_param_scale) for (p, v) in zip(self._params_with_grad, vs)]
|
|
802
|
+
with torch.enable_grad():
|
|
803
|
+
perturbed_returns = closure()
|
|
804
|
+
perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]
|
|
805
|
+
perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)
|
|
806
|
+
Hvs = [(perturbed_g - g)/self._delta_param_scale for (perturbed_g, g) in zip(perturbed_grads, grads)]
|
|
807
|
+
[p.sub_(v, alpha=self._delta_param_scale) for (p, v) in zip(self._params_with_grad, vs)] # remove the perturbation
|
|
808
|
+
|
|
809
|
+
if self._QLs_exprs is None: # initialize QLs on the fly if it is None
|
|
810
|
+
scale = (sum([torch.sum(torch.abs(v)**2) for v in vs])/sum([v.numel() for v in vs])) ** (1/4) # (mean(|v|^2))^(1/4)
|
|
811
|
+
scale = scale * (max([torch.mean((torch.abs(h))**4) for h in Hvs]) + self.damping**4) ** (-1/8) # (mean(|v|^2))^(1/4) * (mean(|h|^4))^(-1/8)
|
|
812
|
+
self._QLs_exprs = [init_kron(h.squeeze(), scale, self._preconditioner_max_size, self._preconditioner_max_skew, self._dQ) for h in Hvs]
|
|
813
|
+
# update preconditioner
|
|
814
|
+
[self._update_precond(*QL_exprs, v.squeeze(), h.squeeze(), lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
|
|
815
|
+
for (QL_exprs, v, h) in zip(self._QLs_exprs, vs, Hvs)]
|
|
816
|
+
else: # only evaluate the gradients
|
|
817
|
+
with torch.enable_grad():
|
|
818
|
+
closure_returns = closure()
|
|
819
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
820
|
+
grads = torch.autograd.grad(loss, self._params_with_grad)
|
|
821
|
+
|
|
822
|
+
grads = [g.squeeze() for g in grads]
|
|
823
|
+
if self.momentum > 0: # precondition the momentum
|
|
824
|
+
beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
|
|
825
|
+
self._counter_m += 1
|
|
826
|
+
if self._ms is None:
|
|
827
|
+
self._ms = [torch.zeros_like(g) for g in grads]
|
|
828
|
+
|
|
829
|
+
[m.mul_(beta).add_(g, alpha=1 - beta) for (m, g) in zip(self._ms, grads)]
|
|
830
|
+
pre_grads = [self._precond_grad(*QL_exprs, m) for (QL_exprs, m) in zip(self._QLs_exprs, self._ms)]
|
|
831
|
+
else: # precondition the gradient
|
|
832
|
+
self._ms, self._counter_m = None, 0 # clear the buffer and counter when momentum is set to zero
|
|
833
|
+
pre_grads = [self._precond_grad(*QL_exprs, g) for (QL_exprs, g) in zip(self._QLs_exprs, grads)]
|
|
834
|
+
|
|
835
|
+
lr = self.lr_params
|
|
836
|
+
if self.grad_clip_max_norm < float("inf"):
|
|
837
|
+
grad_norm = torch.sqrt(torch.real(sum([torch.sum(g*g.conj()) for g in pre_grads])))
|
|
838
|
+
if grad_norm > self.grad_clip_max_norm:
|
|
839
|
+
lr = lr * self.grad_clip_max_norm / grad_norm
|
|
840
|
+
|
|
841
|
+
# Update the parameters.
|
|
842
|
+
[param.subtract_(lr*g.view_as(param)) for (param, g) in zip(self._params_with_grad, pre_grads)]
|
|
843
|
+
|
|
844
|
+
# return whatever closure returns
|
|
845
|
+
return closure_returns
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
############# End of PSGD Kronecker product preconditioners #############
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
############# Begin of PSGD LRA (low rank approximation) preconditioners #############
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
def IpUVtmatvec(U, V, x):
    """
    Compute (I + U*V')*x without forming the matrix I + U*V'.

    U, V and x are expected to be 2D (matrices or column vectors), as required by torch.mm.
    """
    projected = torch.mm(V.t(), x)  # V' * x
    return x + torch.mm(U, projected)
|
|
859
|
+
|
|
860
|
+
|
|
861
|
+
def update_precond_lra(UVd, Luvd, v, h, lr=0.1, betaL=0.9):
    """
    Raw update of the LRA preconditioner Q = (I + U*V')*diag(d) from a pair (v, h).

    h can be a Hessian-vector product associated with v, or a gradient/momentum independent of v.
    The state (U, V, d) and the Lipschitz smoothness estimates (Lu, Lv, Ld) are all updated in place;
    d is updated on every call, while a coin flip decides whether U or V is updated.
    Damping logic is not implemented here.
    U, V, d, v and h are all either matrices or column vectors.
    """
    U, V, d = UVd
    Lu, Lv, Ld = Luvd

    def _track(estimate, ell):
        # EMA of the smoothness constant that never drops below the current value
        estimate.copy_(torch.max(betaL*estimate + (1 - betaL)*ell, ell))

    # Approximately balance U and V such that U^T U = V^T V (exact balancing needs three EVDs)
    UtU = U.t() @ U
    VtV = V.t() @ V
    trUtU = torch.sum(UtU.diagonal())
    trVtV = torch.sum(VtV.diagonal())
    rho = (trUtU/trVtV) ** (1/4)  # will scale U and V as U <-- U/rho and V <-- V*rho
    rho2 = rho * rho
    E = 0.1 * (UtU/rho2 - VtV*rho2)/(trUtU/rho2 + trVtV*rho2)  # errors after scaling U and V
    E2 = 0.5 * E @ E  # this E2 term makes (I - E + E^2/2)(I + E + E^2/2) = (I + E^2/2)^2 - E^2 = I + E^4/4
    # scale U and V to have ||U||_F = ||V||_F
    U.div_(rho)
    V.mul_(rho)
    # rotate (as tr(E)=0) U and V to approach U^TU = V^TV
    U.sub_(U @ (E - E2))
    V.add_(V @ (E + E2))

    Qh = IpUVtmatvec(U, V, d * h)
    Ph = d*IpUVtmatvec(V, U, Qh)

    IpVtU = V.t().mm(U)
    IpVtU.diagonal().add_(1)  # avoid forming matrix I explicitly
    v_over_d = v/d
    LU, pivots, _ = torch.linalg.lu_factor_ex(lift2single(IpVtU))
    adj_solution = torch.linalg.lu_solve(LU, pivots, lift2single(U.t().mm(v_over_d)), adjoint=True)
    invQtv = v_over_d - V.mm(adj_solution.to(V.dtype))
    fwd_solution = torch.linalg.lu_solve(LU, pivots, lift2single(V.t().mm(invQtv)))
    invPv = invQtv - U.mm(fwd_solution.to(U.dtype))
    invPv = invPv/d

    # update d
    Phh = Ph*h
    vinvPv = v*invPv
    _track(Ld, torch.max(torch.abs(Phh)) + torch.max(torch.abs(vinvPv)))
    d.sub_(lr/Ld*(Phh - vinvPv)*d)  # d.mul_(1 - lr/Ld*(Phh - vinvPv)): larger roundoff errors, unstable with bfloat16 and lr<<1

    a, b = Qh, invQtv
    if torch.rand([]) < 0.5:  # only update U
        atV = a.t().mm(V)
        btV = b.t().mm(V)
        atVVt = atV.mm(V.t())
        btVVt = btV.mm(V.t())
        _track(Lu, torch.linalg.vector_norm(a)*torch.linalg.vector_norm(atVVt) +
                   torch.linalg.vector_norm(b)*torch.linalg.vector_norm(btVVt))
        U.sub_(lr/Lu * ( a.mm(atV.mm(IpVtU)) - b.mm(btV.mm(IpVtU)) ))
    else:  # only update V
        atU = a.t().mm(U)
        btU = b.t().mm(U)
        UUta = U.mm(atU.t())
        UUtb = U.mm(btU.t())
        _track(Lv, torch.linalg.vector_norm(a)*torch.linalg.vector_norm(UUta) +
                   torch.linalg.vector_norm(b)*torch.linalg.vector_norm(UUtb))
        V.sub_(lr/Lv * ( (a + V.mm(atU.t())).mm(atU) - (b + V.mm(btU.t())).mm(btU) ))
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
def precond_grad_lra(UVd, g):
    """
    Return the gradient g preconditioned with Q = (I + U*V')*diag(d).

    Q is applied first and then diag(d)*(I + V*U'), i.e. the result is Q'*Q*g.
    All variables here are either matrices or column vectors.
    """
    U, V, d = UVd
    Qg = IpUVtmatvec(U, V, d * g)
    return d * IpUVtmatvec(V, U, Qg)
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
def update_precond_lra_whiten(UVd, Luvd, g, lr=0.1, betaL=0.9, damping=1e-9):
    """
    Update the LRA whitening preconditioner from gradient (or momentum) g.

    A Gaussian probe vector v is drawn, and update_precond_lra is called with the pair (v, g + damping*v).
    """
    probe = torch.randn_like(g)
    damped_g = g + damping*probe
    update_precond_lra(UVd, Luvd, probe, damped_g, lr=lr, betaL=betaL)
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
class LRAWhiten:
|
|
940
|
+
"""
|
|
941
|
+
Implements the PSGD LRA gradient/momentum whitening preconditioner as a class.
|
|
942
|
+
Most of the time, the hyperparameter name says it all. Here are some comments on a few key parameters.
|
|
943
|
+
|
|
944
|
+
1, rank_of_approximation.
|
|
945
|
+
Preconditioner Q has a diagonal part and a low rank part, whose rank is decided by this setting.
|
|
946
|
+
Rank 0 reduces Q to a diagonal preconditioner.
|
|
947
|
+
|
|
948
|
+
2, grad_clip_max_amp, betaL and damping. These three together help to stabilize the training.
|
|
949
|
+
PSGD here tries to normalize the gradients to unit amplitude. This can be problematic when gradients approach zeros.
|
|
950
|
+
The most effective way is to clip the preconditioned gradients when their amplitudes exceed grad_clip_max_amp, say 1.0.
|
|
951
|
+
Another way is to damp and upper bound the fitted preconditioner as P < eye/damping.
|
|
952
|
+
For extremely sparse gradient, increasing betaL (say to 0.999) also helps a lot, where betaL is the EMA factor for the L-smoothness constant (wrt Q) estimation.
|
|
953
|
+
|
|
954
|
+
3, Lastly, Q is initialized to preconditioner_init_scale * eye.
|
|
955
|
+
Boolean setting whiten_grad decides to whiten whether the gradient or momentum.
|
|
956
|
+
Always good to check https://arxiv.org/abs/2402.11858 for math details.
|
|
957
|
+
"""
|
|
958
|
+
def __init__(self, params_with_grad, rank_of_approximation:int=10, preconditioner_init_scale:float|None=None,
|
|
959
|
+
lr_params=0.001, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
|
|
960
|
+
grad_clip_max_amp=float("inf"), preconditioner_update_probability=1.0, whiten_grad=True):
|
|
961
|
+
# mutable members
|
|
962
|
+
self.lr_params = lr_params
|
|
963
|
+
self.lr_preconditioner = lr_preconditioner
|
|
964
|
+
self.betaL = betaL # set to a large betaL for sparse gradients
|
|
965
|
+
self.damping = damping # to damp and upper bound P as P < eye/damping
|
|
966
|
+
self.momentum = momentum if (0<momentum<1) else 0.0
|
|
967
|
+
self.grad_clip_max_amp = grad_clip_max_amp
|
|
968
|
+
self.preconditioner_update_probability = preconditioner_update_probability
|
|
969
|
+
# protected members
|
|
970
|
+
params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
|
|
971
|
+
self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
|
|
972
|
+
dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device
|
|
973
|
+
self._param_sizes = [torch.numel(param) for param in self._params_with_grad]
|
|
974
|
+
self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)
|
|
975
|
+
num_params = self._param_cumsizes[-1]
|
|
976
|
+
assert 0 <= rank_of_approximation < num_params, "Rank r should be in range [0, number of total parameters)"
|
|
977
|
+
self._UVd = [] # saves U, V and d
|
|
978
|
+
self._UVd.append(torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device)) # U
|
|
979
|
+
self._UVd[0] *= 0.1**0.5 / torch.linalg.vector_norm(self._UVd[0])
|
|
980
|
+
self._UVd.append(torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device)) # V
|
|
981
|
+
self._UVd[1] *= 0.1**0.5 / torch.linalg.vector_norm(self._UVd[1])
|
|
982
|
+
if preconditioner_init_scale is None:
|
|
983
|
+
print("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
|
|
984
|
+
else:
|
|
985
|
+
self._UVd.append(torch.ones(num_params, 1, dtype=dtype, device=device) * preconditioner_init_scale)
|
|
986
|
+
self._Luvd = [lift2single(torch.zeros([], dtype=dtype, device=device)) for _ in range(3)]
|
|
987
|
+
self._m, self._counter_m = None, 0 # momentum buffer and counter
|
|
988
|
+
self._whiten_grad = whiten_grad
|
|
989
|
+
if (not whiten_grad):
|
|
990
|
+
assert self.momentum > 0, "Cannot whiten momentum if the momentum setting is zero."
|
|
991
|
+
print(f"Recommend to reduce lr_params by {int(((1 + momentum)/(1 - momentum))**0.5)} times")
|
|
992
|
+
|
|
993
|
+
|
|
994
|
+
@torch.no_grad()
def step(self, closure):
    """
    Run one update of the PSGD LRA gradient/momentum whitening optimizer.

    closure: callable returning either the loss tensor or a sequence whose first item is the loss.
    Returns whatever the closure returned.
    """
    # the whole step runs under no_grad, so autograd is re-enabled only around the loss evaluation
    with torch.enable_grad():
        closure_returns = closure()
        if isinstance(closure_returns, torch.Tensor):
            loss = closure_returns
        else:
            loss = closure_returns[0]
        grads = torch.autograd.grad(loss, self._params_with_grad)

    # flatten every gradient and stack them into a single column vector
    grad = torch.cat([g.reshape(-1, 1) for g in grads])

    if len(self._UVd) < 3:
        # d was not created in __init__; derive its initial scale from the current gradient
        d_scale = (torch.mean(grad**4) + self.damping**4)**(-1/8)
        self._UVd.append(d_scale * torch.ones_like(grad))

    use_momentum = self.momentum > 0
    if use_momentum:
        # beta ramps up with the step counter until it reaches the configured momentum
        beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
        self._counter_m += 1
        if self._m is None:
            self._m = torch.zeros_like(grad)
        self._m.mul_(beta).add_(grad, alpha=1 - beta)
    else:
        # momentum switched off: drop the buffer and restart the counter
        self._m, self._counter_m = None, 0

    if torch.rand([]) < self.preconditioner_update_probability:
        # whiten either the raw gradient or the momentum buffer, depending on the flag set at construction
        whiten_target = grad if self._whiten_grad else self._m
        update_precond_lra_whiten(self._UVd, self._Luvd, whiten_target,
                                  lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)

    # precondition the momentum when it is in use, otherwise the gradient
    pre_grad = precond_grad_lra(self._UVd, self._m if use_momentum else grad)

    lr = self.lr_params
    max_amp = self.grad_clip_max_amp
    if max_amp < float("inf"):
        # amplitude is the root-mean-square of the preconditioned gradient
        amp = torch.sqrt(torch.mean(pre_grad * pre_grad))
        if amp > max_amp:
            lr = lr * max_amp/amp

    # each parameter takes its own slice [end - size, end) of the flat preconditioned gradient
    for param, size, end in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes):
        param.subtract_(lr * pre_grad[end - size:end].view_as(param))

    return closure_returns
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
def update_precond_lra_newton(UVd, Luvd, v, h, lr=0.1, betaL=0.9, damping=1e-9):
    """
    Update the LRA Newton preconditioner in place from one (v, h) pair.

    Gaussian noise scaled by `damping` is added to h, and the result is handed to
    update_precond_lra together with UVd, Luvd, v, lr and betaL.
    """
    damped_h = h + damping*torch.randn_like(h)
    update_precond_lra(UVd, Luvd, v, damped_h, lr=lr, betaL=betaL)
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
class LRANewton:
|
|
1053
|
+
"""
|
|
1054
|
+
Implements the PSGD LRA Newton-type preconditioner as a class.
|
|
1055
|
+
Most of the time, the hyperparameter name says it all. Here are some comments on a few key parameters.
|
|
1056
|
+
|
|
1057
|
+
1, rank_of_approximation.
|
|
1058
|
+
Preconditioner Q has a diagonal part and a low rank part, whose rank is decided by this setting.
|
|
1059
|
+
Rank 0 reduces Q to a diagonal preconditioner.
|
|
1060
|
+
|
|
1061
|
+
2, grad_clip_max_norm, betaL and damping. These three together help to stabilize the training.
|
|
1062
|
+
The grad_clip_max_norm is used to clip the preconditioned gradient to stabilize the optimization as in the classic trust region method.
|
|
1063
|
+
Setting damping is used to damp and upper bound the preconditioner as P < eye/damping.
|
|
1064
|
+
For extremely sparse hess-vector-prods, a large betaL (say 0.999) helps a lot, where betaL is the EMA factor for the L-smoothness constant (wrt Q) estimation.
|
|
1065
|
+
|
|
1066
|
+
3, exact_hessian_vector_product.
|
|
1067
|
+
By setting this flag to False, the finite difference method will be used for Hvp approximation.
|
|
1068
|
+
Be cautious with the finite difference method (possible numerical issues; the closure must behave like a pure function).
|
|
1069
|
+
|
|
1070
|
+
4, Lastly, Q is initialized to preconditioner_init_scale * eye.
|
|
1071
|
+
Both lr_params and lr_preconditioner are normalized learning rates.
|
|
1072
|
+
Always good to check https://arxiv.org/abs/2402.11858 for math details.
|
|
1073
|
+
"""
|
|
1074
|
+
def __init__(self, params_with_grad, rank_of_approximation:int=10, preconditioner_init_scale:float|None=None,
|
|
1075
|
+
lr_params=0.01, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
|
|
1076
|
+
grad_clip_max_norm=float("inf"), preconditioner_update_probability=1.0,
|
|
1077
|
+
exact_hessian_vector_product=True):
|
|
1078
|
+
# mutable members
|
|
1079
|
+
self.lr_params = lr_params
|
|
1080
|
+
self.lr_preconditioner = lr_preconditioner
|
|
1081
|
+
self.betaL = betaL # set to a large betaL for sparse Hvp
|
|
1082
|
+
self.damping = damping # to damp and upper bound the preconditioner as P < eye/damping
|
|
1083
|
+
self.momentum = momentum if (0<momentum<1) else 0.0
|
|
1084
|
+
self.grad_clip_max_norm = grad_clip_max_norm
|
|
1085
|
+
self.preconditioner_update_probability = preconditioner_update_probability
|
|
1086
|
+
# protected members
|
|
1087
|
+
params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
|
|
1088
|
+
self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
|
|
1089
|
+
dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device
|
|
1090
|
+
self._delta_param_scale = torch.finfo(dtype).eps**0.5
|
|
1091
|
+
self._param_sizes = [torch.numel(param) for param in self._params_with_grad]
|
|
1092
|
+
self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)
|
|
1093
|
+
num_params = self._param_cumsizes[-1]
|
|
1094
|
+
assert 0 <= rank_of_approximation < num_params, "Rank r should be in range [0, number of total parameters)"
|
|
1095
|
+
self._UVd = [] # saves U, V and d
|
|
1096
|
+
self._UVd.append(torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device)) # U
|
|
1097
|
+
self._UVd[0] *= 0.1**0.5 / torch.linalg.vector_norm(self._UVd[0])
|
|
1098
|
+
self._UVd.append(torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device)) # V
|
|
1099
|
+
self._UVd[1] *= 0.1**0.5 / torch.linalg.vector_norm(self._UVd[1])
|
|
1100
|
+
if preconditioner_init_scale is None:
|
|
1101
|
+
print("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
|
|
1102
|
+
else:
|
|
1103
|
+
self._UVd.append(torch.ones(num_params, 1, dtype=dtype, device=device) * preconditioner_init_scale)
|
|
1104
|
+
self._Luvd = [lift2single(torch.zeros([], dtype=dtype, device=device)) for _ in range(3)]
|
|
1105
|
+
self._m, self._counter_m = None, 0 # momentum buffer and counter
|
|
1106
|
+
self._exact_hessian_vector_product = exact_hessian_vector_product
|
|
1107
|
+
if not exact_hessian_vector_product:
|
|
1108
|
+
print("FYI: Approximate Hvp with finite-difference method. Make sure that: 1) the closure behaves like a pure function; 2) delta param scale is proper.")
|
|
1109
|
+
|
|
1110
|
+
|
|
1111
|
+
@torch.no_grad()
|
|
1112
|
+
def step(self, closure):
|
|
1113
|
+
"""
|
|
1114
|
+
Performs one step of the PSGD LRA Newton optimizer.
|
|
1115
|
+
"""
|
|
1116
|
+
if (torch.rand([]) < self.preconditioner_update_probability) or (len(self._UVd) < 3):
|
|
1117
|
+
# evaluates gradients, Hessian-vector product, and updates the preconditioner
|
|
1118
|
+
if self._exact_hessian_vector_product:
|
|
1119
|
+
with torch.enable_grad():
|
|
1120
|
+
closure_returns = closure()
|
|
1121
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
1122
|
+
grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)
|
|
1123
|
+
vs = [torch.randn_like(param) for param in self._params_with_grad]
|
|
1124
|
+
Hvs = torch.autograd.grad(grads, self._params_with_grad, vs)
|
|
1125
|
+
else: # approximate Hessian-vector product via finite-difference formulae. Use it with cautions.
|
|
1126
|
+
with torch.enable_grad():
|
|
1127
|
+
closure_returns = closure()
|
|
1128
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
1129
|
+
grads = torch.autograd.grad(loss, self._params_with_grad)
|
|
1130
|
+
|
|
1131
|
+
vs = [torch.randn_like(param) for param in self._params_with_grad]
|
|
1132
|
+
[param.add_(v, alpha=self._delta_param_scale) for (param, v) in zip(self._params_with_grad, vs)]
|
|
1133
|
+
with torch.enable_grad():
|
|
1134
|
+
perturbed_returns = closure()
|
|
1135
|
+
perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]
|
|
1136
|
+
perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)
|
|
1137
|
+
Hvs = [(perturbed_g - g)/self._delta_param_scale for (perturbed_g, g) in zip(perturbed_grads, grads)]
|
|
1138
|
+
[param.sub_(v, alpha=self._delta_param_scale) for (param, v) in zip(self._params_with_grad, vs)]
|
|
1139
|
+
|
|
1140
|
+
v = torch.cat([torch.reshape(v, [-1, 1]) for v in vs]) # column vector
|
|
1141
|
+
h = torch.cat([torch.reshape(h, [-1, 1]) for h in Hvs]) # column vector
|
|
1142
|
+
if len(self._UVd) < 3: # init d if it's not in the UVd list
|
|
1143
|
+
self._UVd.append((torch.mean(v*v))**(1/4) * (torch.mean(h**4) + self.damping**4)**(-1/8) * torch.ones_like(v))
|
|
1144
|
+
|
|
1145
|
+
# update preconditioner
|
|
1146
|
+
update_precond_lra_newton(self._UVd, self._Luvd, v, h, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
|
|
1147
|
+
else: # only evaluates the gradients
|
|
1148
|
+
with torch.enable_grad():
|
|
1149
|
+
closure_returns = closure()
|
|
1150
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
1151
|
+
grads = torch.autograd.grad(loss, self._params_with_grad)
|
|
1152
|
+
|
|
1153
|
+
# cat grads
|
|
1154
|
+
grad = torch.cat([torch.reshape(g, [-1, 1]) for g in grads]) # column vector
|
|
1155
|
+
|
|
1156
|
+
if self.momentum > 0: # precondition momentum
|
|
1157
|
+
beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
|
|
1158
|
+
self._counter_m += 1
|
|
1159
|
+
if self._m is None:
|
|
1160
|
+
self._m = torch.zeros_like(grad)
|
|
1161
|
+
|
|
1162
|
+
self._m.mul_(beta).add_(grad, alpha=1 - beta)
|
|
1163
|
+
pre_grad = precond_grad_lra(self._UVd, self._m)
|
|
1164
|
+
else: # precondition gradient
|
|
1165
|
+
self._m, self._counter_m = None, 0 # clear the buffer and counter when momentum is set to zero
|
|
1166
|
+
pre_grad = precond_grad_lra(self._UVd, grad)
|
|
1167
|
+
|
|
1168
|
+
lr = self.lr_params
|
|
1169
|
+
if self.grad_clip_max_norm < float("inf"):
|
|
1170
|
+
grad_norm = torch.linalg.vector_norm(pre_grad)
|
|
1171
|
+
if grad_norm > self.grad_clip_max_norm:
|
|
1172
|
+
lr = lr * self.grad_clip_max_norm / grad_norm
|
|
1173
|
+
|
|
1174
|
+
# update the parameters
|
|
1175
|
+
[param.subtract_(lr * pre_grad[j - i:j].view_as(param))
|
|
1176
|
+
for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes)]
|
|
1177
|
+
|
|
1178
|
+
# return whatever closure returns
|
|
1179
|
+
return closure_returns
|
|
1180
|
+
|
|
1181
|
+
|
|
1182
|
+
############# End of PSGD LRA preconditioners #############
|
|
1183
|
+
|
|
1184
|
+
|
|
1185
|
+
############# Begin of PSGD dense matrix Newton-type preconditioner #############
|
|
1186
|
+
|
|
1187
|
+
|
|
1188
|
+
def update_precond_dense_eq(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
    """
    Update dense matrix Newton-type preconditioner Q with local coordinate dQ = mathcal{E} * Q.

    Q and L are modified in place; v and h are column vectors. Gaussian noise scaled by
    `damping` is added to h before it is used.
    """
    damped_h = h + damping*torch.randn_like(h)
    a = Q.mm(damped_h)
    # Q.t() is handed to the solver as a lower-triangular system (upper=False)
    b = torch.linalg.solve_triangular(lift2single(Q.t()), lift2single(v), upper=False).to(v.dtype)
    ell = torch.sum(a*a + b*b)
    # L becomes the larger of its EMA-smoothed value and the fresh estimate ell
    L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
    # only the upper-triangular part of a*a' - b*b' enters the update
    upper_part = torch.triu(a @ a.T - b @ b.T)
    Q.sub_((lr/L * upper_part) @ Q)
|
|
1197
|
+
|
|
1198
|
+
|
|
1199
|
+
def update_precond_dense_qep(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
    """
    Update dense matrix Newton-type preconditioner Q with local coordinate dQ = Q * mathcal{E} * P.

    Q and L are modified in place; v and h are column vectors. Gaussian noise scaled by
    `damping` is added to h before it is used.
    """
    damped_h = h + damping*torch.randn_like(h)
    a = Q @ (Q.T @ (Q @ damped_h))
    b = Q @ v
    ell = torch.sum(a*a + b*b)
    # L becomes the larger of its EMA-smoothed value and the fresh estimate ell
    L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
    step_size = lr/L
    Q.sub_(step_size * (a @ (a.T @ Q) - b @ (b.T @ Q)))
|
|
1208
|
+
|
|
1209
|
+
|
|
1210
|
+
def update_precond_dense_qeq(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
    """
    Update dense matrix Newton-type preconditioner Q with local coordinate dQ = Q * mathcal{E} * Q.

    Q and L are modified in place; v and h are column vectors. Gaussian noise scaled by
    `damping` is added to h before it is used.
    """
    damped_h = h + damping*torch.randn_like(h)
    a = Q.T @ (Q @ damped_h)
    ell = torch.sum(a*a + v*v)
    # L becomes the larger of its EMA-smoothed value and the fresh estimate ell
    L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
    step_size = lr/L
    Q.sub_(step_size * ((Q @ a) @ a.T - (Q @ v) @ v.T))
|
|
1218
|
+
|
|
1219
|
+
|
|
1220
|
+
def update_precond_dense_q0p5eq1p5(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
    """
    Update dense matrix Newton-type preconditioner Q with local coordinate dQ = Q^0.5 * mathcal{E} * Q^1.5.

    Q and L are modified in place; v and h are column vectors. Gaussian noise scaled by
    `damping` is added to h before it is used. After the update, Q is passed to procrustes_step.
    """
    damped_h = h + damping*torch.randn_like(h)
    a = Q.T @ (Q @ damped_h)
    ell = torch.sum(a*a + v*v)
    # L becomes the larger of its EMA-smoothed value and the fresh estimate ell
    L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
    step_size = lr/L
    Q.sub_(step_size * (a @ (a.T @ Q) - v @ (v.T @ Q)))
    procrustes_step(Q)
|
|
1229
|
+
|
|
1230
|
+
|
|
1231
|
+
def update_precond_dense_quad(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
    """
    Update dense matrix Newton-type preconditioner Q with a quadratic form for dQ.

    Q and L are modified in place; v and h are column vectors. Gaussian noise scaled by
    `damping` is added to h before it is used. The result written back to Q is symmetrized.
    """
    damped_h = h + damping*torch.randn_like(h)
    a = Q @ (Q @ damped_h) # Q is symmetric here
    ell = torch.sum(a*a + v*v)
    # L becomes the larger of its EMA-smoothed value and the fresh estimate ell
    L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
    half_step = lr/2/L
    # two half-steps: the first multiplies Q from the right, the second multiplies the intermediate from the left
    p = Q - half_step * (a @ (a.T @ Q) - v @ (v.T @ Q))
    p = p - half_step * ((p @ a) @ a.T - (p @ v) @ v.T)
    Q.copy_((p + p.T)/2)
|
|
1241
|
+
|
|
1242
|
+
|
|
1243
|
+
def update_precond_dense_quad4p(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
    """
    The only case that fits P directly.

    Q (which holds P here) and L are modified in place; v and h are column vectors. Gaussian noise
    scaled by `damping` is added to h before it is used. The result written back to Q is symmetrized.
    """
    damped_h = h + damping*torch.randn_like(h)
    a = Q @ damped_h # Q actually is P; so just apply it once.
    ell = torch.sum(a*a + v*v)
    # L becomes the larger of its EMA-smoothed value and the fresh estimate ell
    L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
    step_size = lr/L
    # two full steps: the first multiplies Q from the right, the second multiplies the intermediate from the left
    p = Q - step_size * (a @ (a.T @ Q) - v @ (v.T @ Q))
    p = p - step_size * ((p @ a) @ a.T - (p @ v) @ v.T)
    Q.copy_((p + p.T)/2)
|
|
1253
|
+
|
|
1254
|
+
|
|
1255
|
+
class DenseNewton:
|
|
1256
|
+
"""
|
|
1257
|
+
Implements the PSGD dense matrix Newton-type preconditioner as a class.
|
|
1258
|
+
Be extra cautious when using the finite difference method for Hvp approximation (the closure must behave like a pure function).
|
|
1259
|
+
It's mainly for illustrating how PSGD works due to its simplicity.
|
|
1260
|
+
It's also a good alternative to the BFGS like quasi-Newton methods as no line search is required.
|
|
1261
|
+
"""
|
|
1262
|
+
def __init__(self, params_with_grad, preconditioner_init_scale:float|None=None,
|
|
1263
|
+
lr_params=0.01, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
|
|
1264
|
+
grad_clip_max_norm=float("inf"), preconditioner_update_probability=1.0,
|
|
1265
|
+
exact_hessian_vector_product=True, dQ="Q0.5EQ1.5"):
|
|
1266
|
+
# mutable members
|
|
1267
|
+
self.lr_params = lr_params
|
|
1268
|
+
self.lr_preconditioner = lr_preconditioner
|
|
1269
|
+
self.betaL = betaL # set to a large betaL for sparse Hvp
|
|
1270
|
+
self.damping = damping # to damp and upper bound the preconditioner as P < eye/damping
|
|
1271
|
+
self.momentum = momentum if (0<momentum<1) else 0.0
|
|
1272
|
+
self.grad_clip_max_norm = grad_clip_max_norm
|
|
1273
|
+
self.preconditioner_update_probability = preconditioner_update_probability
|
|
1274
|
+
# protected members
|
|
1275
|
+
params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
|
|
1276
|
+
self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
|
|
1277
|
+
dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device
|
|
1278
|
+
self._delta_param_scale = torch.finfo(dtype).eps ** 0.5
|
|
1279
|
+
self._param_sizes = [torch.numel(param) for param in self._params_with_grad]
|
|
1280
|
+
self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)
|
|
1281
|
+
num_params = self._param_cumsizes[-1]
|
|
1282
|
+
if preconditioner_init_scale is None: # initialize Q on the fly
|
|
1283
|
+
self._Q = None
|
|
1284
|
+
else:
|
|
1285
|
+
if dQ == "QUAD4P": # Q actually is P
|
|
1286
|
+
preconditioner_init_scale *= preconditioner_init_scale
|
|
1287
|
+
self._Q = torch.eye(num_params, dtype=dtype, device=device) * preconditioner_init_scale
|
|
1288
|
+
self._L = lift2single(torch.zeros([], dtype=dtype, device=device)) # Lipschitz smoothness constant estimation for the psgd criterion
|
|
1289
|
+
self._m, self._counter_m = None, 0 # buffer and counter for momentum
|
|
1290
|
+
self._exact_hessian_vector_product = exact_hessian_vector_product
|
|
1291
|
+
if not exact_hessian_vector_product:
|
|
1292
|
+
print("FYI: Approximate Hvp with finite-difference method. Make sure that: 1) the closure behaves like a pure function; 2) delta param scale is proper.")
|
|
1293
|
+
self._dQ = dQ
|
|
1294
|
+
if dQ == "QUAD4P": # the only case that we fit P directly
|
|
1295
|
+
self._update_precond = update_precond_dense_quad4p
|
|
1296
|
+
self._precond_grad = lambda Q, g: Q @ g
|
|
1297
|
+
assert torch.finfo(dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
|
|
1298
|
+
elif dQ == "QUAD":
|
|
1299
|
+
self._update_precond = update_precond_dense_quad
|
|
1300
|
+
self._precond_grad = lambda Q, g: Q @ (Q @ g) # Q is symmetric; just save one transpose
|
|
1301
|
+
else:
|
|
1302
|
+
self._precond_grad = lambda Q, g: Q.T @ (Q @ g)
|
|
1303
|
+
if dQ == "QEP":
|
|
1304
|
+
self._update_precond = update_precond_dense_qep
|
|
1305
|
+
elif dQ == "EQ":
|
|
1306
|
+
self._update_precond = update_precond_dense_eq
|
|
1307
|
+
elif dQ == "QEQ":
|
|
1308
|
+
self._update_precond = update_precond_dense_qeq
|
|
1309
|
+
else:
|
|
1310
|
+
assert (dQ == "Q0p5EQ1p5") or (dQ == "Q0.5EQ1.5"), "Invalid choice for dQ"
|
|
1311
|
+
self._update_precond = update_precond_dense_q0p5eq1p5
|
|
1312
|
+
|
|
1313
|
+
|
|
1314
|
+
@torch.no_grad()
|
|
1315
|
+
def step(self, closure):
|
|
1316
|
+
"""
|
|
1317
|
+
Performs one step of PSGD with the dense matrix Newton-type preconditioner.
|
|
1318
|
+
"""
|
|
1319
|
+
if (torch.rand([]) < self.preconditioner_update_probability) or (self._Q is None):
|
|
1320
|
+
# evaluates gradients, Hessian-vector product, and updates the preconditioner
|
|
1321
|
+
if self._exact_hessian_vector_product: # exact Hessian-vector product
|
|
1322
|
+
with torch.enable_grad():
|
|
1323
|
+
closure_returns = closure()
|
|
1324
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
1325
|
+
grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)
|
|
1326
|
+
vs = [torch.randn_like(param) for param in self._params_with_grad]
|
|
1327
|
+
Hvs = torch.autograd.grad(grads, self._params_with_grad, vs)
|
|
1328
|
+
else: # approximate Hessian-vector product via finite-difference formulae. Use it with cautions.
|
|
1329
|
+
with torch.enable_grad():
|
|
1330
|
+
closure_returns = closure()
|
|
1331
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
1332
|
+
grads = torch.autograd.grad(loss, self._params_with_grad)
|
|
1333
|
+
|
|
1334
|
+
vs = [torch.randn_like(param) for param in self._params_with_grad]
|
|
1335
|
+
[param.add_(v, alpha=self._delta_param_scale) for (param, v) in zip(self._params_with_grad, vs)]
|
|
1336
|
+
with torch.enable_grad():
|
|
1337
|
+
perturbed_returns = closure()
|
|
1338
|
+
perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]
|
|
1339
|
+
perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)
|
|
1340
|
+
Hvs = [(perturbed_g - g)/self._delta_param_scale for (perturbed_g, g) in zip(perturbed_grads, grads)]
|
|
1341
|
+
[param.sub_(v, alpha=self._delta_param_scale) for (param, v) in zip(self._params_with_grad, vs)]
|
|
1342
|
+
|
|
1343
|
+
v = torch.cat([torch.reshape(v, [-1, 1]) for v in vs])
|
|
1344
|
+
h = torch.cat([torch.reshape(h, [-1, 1]) for h in Hvs])
|
|
1345
|
+
if self._Q is None: # initialize Q on the fly if it is None
|
|
1346
|
+
scale = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + self.damping**4)**(-1/8)
|
|
1347
|
+
if self._dQ == "QUAD4P": # Q actually is P in this case
|
|
1348
|
+
scale *= scale
|
|
1349
|
+
self._Q = torch.eye(len(v), dtype=v.dtype, device=v.device) * scale
|
|
1350
|
+
|
|
1351
|
+
# update preconditioner
|
|
1352
|
+
self._update_precond(self._Q, self._L, v, h, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
|
|
1353
|
+
else: # only evaluates the gradients
|
|
1354
|
+
with torch.enable_grad():
|
|
1355
|
+
closure_returns = closure()
|
|
1356
|
+
loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
|
|
1357
|
+
grads = torch.autograd.grad(loss, self._params_with_grad)
|
|
1358
|
+
|
|
1359
|
+
# cat grads
|
|
1360
|
+
grad = torch.cat([torch.reshape(g, [-1, 1]) for g in grads])
|
|
1361
|
+
|
|
1362
|
+
if self.momentum > 0: # precondition momentum
|
|
1363
|
+
beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
|
|
1364
|
+
self._counter_m += 1
|
|
1365
|
+
if self._m is None:
|
|
1366
|
+
self._m = torch.zeros_like(grad)
|
|
1367
|
+
|
|
1368
|
+
self._m.mul_(beta).add_(grad, alpha=1 - beta)
|
|
1369
|
+
pre_grad = self._precond_grad(self._Q, self._m)
|
|
1370
|
+
else:
|
|
1371
|
+
self._m, self._counter_m = None, 0 # clear the buffer and counter when momentum is set to zero
|
|
1372
|
+
pre_grad = self._precond_grad(self._Q, grad)
|
|
1373
|
+
|
|
1374
|
+
lr = self.lr_params
|
|
1375
|
+
if self.grad_clip_max_norm < float("inf"):
|
|
1376
|
+
grad_norm = torch.linalg.vector_norm(pre_grad)
|
|
1377
|
+
if grad_norm > self.grad_clip_max_norm:
|
|
1378
|
+
lr = lr * self.grad_clip_max_norm / grad_norm
|
|
1379
|
+
|
|
1380
|
+
# update the parameters
|
|
1381
|
+
[param.subtract_(lr * pre_grad[j - i:j].view_as(param))
|
|
1382
|
+
for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes)]
|
|
1383
|
+
|
|
1384
|
+
# return whatever closure returns
|
|
1385
|
+
return closure_returns
|
|
1386
|
+
|
|
1387
|
+
|
|
1388
|
+
############# End of PSGD dense matrix Newton-type preconditioner #############
|
|
1389
|
+
|
|
1390
|
+
""" end of psgd """
|