torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1390 @@
1
+ # pylint:disable=not-callable
2
+ # this file is from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py (version from Sept. 2025)
3
+ # with a few minor modifications (like passing the Q balancing probability)
4
+ """
5
+ The new PSGD-Kron Newton/Whitening preconditioners support five kinds of local coordinates for updating Q:
6
+
7
+ QUAD): It's a specific form for updating Q to ensure that Q > 0 (thus Q is symmetric/Hermitian).
8
+ This is one of the recommended choices for fitting Q.
9
+
10
+ QEQ): dQ = Q * mathcal{E} * Q
11
+ This leads to another simple way for updating Q (Q is in the general linear group).
12
+ It's another recommended choice for fitting Q.
13
+
14
+ Q0p5EQ1p5): dQ = Q^0.5 * mathcal{E} * Q^1.5
15
+ One more recommended choice for fitting Q.
16
+ An online orthogonal Procrustes problem solver is used to keep Q approximately SPD.
17
+
18
+ EQ): dQ = mathcal{E} * Q
19
+ This choice recovers the old PSGD way for updating Q in Lie groups (Q is triangular).
20
+ Its main drawback is that triangular solvers are required for updating Q.
21
+
22
+ QEP): dQ = Q * mathcal{E} * P
23
+ This last choice works very well when it works. Q is in the general linear group.
24
+ But, one drawback is that Q might get stuck around ill-conditioned matrices (not strongly convex).
25
+
26
+ The QUAD formulae can be used to update P directly (see older commit 0fc33cd).
27
+ I call this choice QUAD4P. It is still a good choice for optimization with single precision.
28
+ Unlike QUAD, QUAD4P does not work well with half precision. Use it with caution.
29
+
30
+ The PSGD-LRA Newton/Whitening preconditioners still adopt local coordinate dQ = mathcal{E} * Q,
31
+ and needs a small linear solver to update the preconditioner.
32
+
33
+ I also keep the PSGD dense matrix Newton-type preconditioner here to illustrate the math.
34
+ It supports all the five methods for updating Q,
35
+ and can be a good alternative to the BFGS like quasi-Newton optimizers as no line search is required.
36
+
37
+ Xi-Lin Li, lixilinx@gmail.com; last updated in Sept., 2025.
38
+ Main refs: https://arxiv.org/abs/1512.04202; https://arxiv.org/abs/2402.11858.
39
+ """
40
+
41
+ from typing import TYPE_CHECKING, cast
42
+
43
+ import torch
44
+
45
+ from ....utils.python_tools import LazyLoader
46
+
47
+ opt_einsum = LazyLoader("opt_einsum")
48
+
49
+ if TYPE_CHECKING:
50
+ import opt_einsum as _opt_einsum
51
+ opt_einsum = cast(_opt_einsum, opt_einsum)
52
+
53
+ def norm_lower_bound_spd(A, k=32, half_iters=2):
54
+ """
55
+ Returns a cheap lower bound for the spectral norm of a symmetric positive definite matrix A, where,
56
+ k: the dim of subspace, suggesting 128 for bfloat16, 32 for float32 and 4 for float64 (tested on my laptop 4070 GPU);
57
+ half_iters: half of the number of subspace iterations, suggesting 2.
58
+ A rough norm estimate is good enough, and we don't orthonormalize the subspace vectors.
59
+
60
+ The initial noise space V is rotated such that its centroid aligns with the largest row of A.
61
+ Hence, each row of V and the largest row of A has an angle about acos(1/sqrt(k)) when k << dim(A).
62
+ This feature makes the subspace iteration more robust for large matrices with very low rank.
63
+ A simplified branchless approximate implementation is provided here.
64
+ """
65
+ smallest_normal = torch.finfo(A.dtype).smallest_normal
66
+ normalizing_factor = A.diagonal().real.amax() + smallest_normal
67
+ A = A / normalizing_factor # (complex tensor) / (subnormal number) could produce inf or nan unexpectedly
68
+ j = torch.argmax(torch.linalg.vector_norm(A, dim=1))
69
+ V = torch.randn(k, A.shape[1], dtype=A.dtype, device=A.device)
70
+ V = A[j] + torch.sgn(torch.sum(A[j] * V.conj(), dim=1, keepdim=True)) * V # torch.sign for real
71
+ for _ in range(half_iters):
72
+ V = V @ A
73
+ V /= torch.linalg.vector_norm(V, dim=1, keepdim=True) + smallest_normal
74
+ V = V @ A
75
+ return normalizing_factor * torch.amax(torch.linalg.vector_norm(V, dim=1))
76
+
77
+
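A quick way to see what this bound does (a hypothetical check, not part of the upstream file; it assumes norm_lower_bound_spd is in scope, e.g. run inside or imported from torchzero/modules/adaptive/psgd/psgd.py):

import torch

torch.manual_seed(0)
B = torch.randn(200, 200)
A = B @ B.T + 1e-3 * torch.eye(200)         # symmetric positive definite
est = norm_lower_bound_spd(A, k=32, half_iters=2)
exact = torch.linalg.matrix_norm(A, ord=2)  # exact spectral norm, for comparison
print(float(est), float(exact))             # est <= exact, and typically within a few percent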
78
+ def norm_lower_bound_skh(A, k=32, half_iters=2):
79
+ """
80
+ Returns a cheap lower bound for the spectral norm of a skew-Hermitian matrix A,
81
+ k: the dim of the subspace; 128 is suggested for bfloat16, 32 for float32 and 4 for float64 (tested on my laptop 4070 GPU);
82
+ half_iters: half of the number of subspace iterations, suggesting 2.
83
+ A rough norm estimate is good enough, and we don't orthonormalize the subspace vectors.
84
+
85
+ The initial noise space V is rotated such that its centroid aligns with the largest row of A.
86
+ Hence, each row of V and the largest row of A has an angle about acos(1/sqrt(k)) when k << dim(A).
87
+ This feature makes the subspace iteration more robust for large matrices with very low rank.
88
+ A simplified branchless approximate implementation is provided here.
89
+ """
90
+ smallest_normal = torch.finfo(A.dtype).smallest_normal
91
+ normalizing_factor = A.abs().amax() + smallest_normal
92
+ A = A / normalizing_factor # (complex tensor) / (subnormal number) could produce inf or nan unexpectedly
93
+ j = torch.argmax(torch.linalg.vector_norm(A, dim=1))
94
+ V = torch.randn(k, A.shape[1], dtype=A.dtype, device=A.device)
95
+ V = A[j] + torch.sgn(torch.sum(A[j] * V.conj(), dim=1, keepdim=True)) * V # torch.sign for real
96
+ for _ in range(half_iters):
97
+ V = V @ A
98
+ V /= torch.linalg.vector_norm(V, dim=1, keepdim=True) + smallest_normal
99
+ V = V @ A
100
+ return normalizing_factor * torch.amax(torch.linalg.vector_norm(V, dim=1))
101
+
102
+
103
+ def lift2single(x):
104
+ # lift half or lower precision to single precision; leave single or higher precision unchanged
105
+ return x.to(torch.float32) if torch.finfo(x.dtype).eps > 1e-6 else x
106
+
107
+
108
+ def procrustes_step(Q, max_step_size=1/8):
109
+ """
110
+ An in-place (updates Q directly) online solver for the orthogonal Procrustes problem,
111
+ min_U || U Q - I ||_F, s.t. U^H U = I
112
+ by rotating Q as exp(a R) Q, where R = Q^H - Q is the generator and ||a R|| < 1.
113
+
114
+ Do not set max_step_size > 1/4 as we only expand exp(a R) to its 2nd term.
115
+
116
+ Note that U(n) is connected and such rotations can make most complex Q SPD except for convergence to saddle points.
117
+ However, O(n) is not connected. Hence, such SO(n) rotations can only make real Q with det(Q) > 0 SPD.
118
+
119
+ We have simplified the original implementation. The one branch here is necessary for line search.
120
+ """
121
+ R = Q.H - Q
122
+ R /= norm_lower_bound_skh(R) + torch.finfo(R.dtype).smallest_normal # normalize R as typically it's too small
123
+ RQ = R @ Q
124
+ RRQ = R @ RQ
125
+ tr_RQ = RQ.diagonal().real.sum() # tr_RQ >=0 by theory; torch.trace not implemented for CPU bfloat16, so sum(diag())
126
+ tr_RRQ = RRQ.diagonal().real.sum() # line search is needed if tr_RRQ < 0
127
+ a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size), max_step_size)
128
+ Q.add_(a * (RQ + 0.5 * a * RRQ))
129
+
130
+
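As a hedged illustration of the docstring above (not part of the upstream file): repeated Procrustes steps drive a random real Q with det(Q) > 0 toward a symmetric positive definite matrix, which shows up as a shrinking relative asymmetry.

import torch

torch.manual_seed(0)
Q = torch.randn(8, 8)
if torch.linalg.det(Q) < 0:   # SO(n) rotations can only symmetrize real Q with det(Q) > 0
    Q[0].neg_()
rel_asym = lambda M: float(torch.linalg.matrix_norm(M - M.T) / torch.linalg.matrix_norm(M))
print(rel_asym(Q))            # O(1) for a random matrix
for _ in range(500):
    procrustes_step(Q)
print(rel_asym(Q))            # typically orders of magnitude smaller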
131
+ ############# Begin of PSGD Kronecker product preconditioners #############
132
+
133
+
134
+ def init_kron(t, Scale=1.0, max_size=float("inf"), max_skew=1.0, dQ="QEQ"):
135
+ """
136
+ For a scalar or tensor t, we initialize its states (preconditioner Q and Lipschitz smoothness constant L),
137
+ and reusable contraction expressions for updating Q and preconditioning gradient.
138
+
139
+ 1, The preconditioner Q is initialized to
140
+ Q = Scale * I = Scale * kron(eye(t.shape[0]), eye(t.shape[1]), ...)
141
+ where the eye(.) may be replaced with diag(ones(.)) if that dim is too large, determined by max_size and max_skew.
142
+
143
+ The Lipschitz smoothness constant L for Q is initialized to zero.
144
+
145
+ 2, A series of einsum contraction expressions. The following subscript examples are for a 5th order tensor.
146
+ 2.1, exprP is the expression for applying the Preconditioner on the gradient, e.g.,
147
+ 'aA,bB,cC,dD,eE,aα,bβ,cγ,dδ,eε,αβγδε->ABCDE'
148
+ 2.2, the i-th expression of exprGs is for the contraction of two tensors that only keeps the i-th dim, e.g.,
149
+ 'abCde,abγde->Cγ'
150
+ for i=2. It's useful for Gradient calculation.
151
+ 2.3, exprA is the expression for applying All the factors of Q on a tensor, e.g.,
152
+ 'aA,bB,cC,dD,eE,ABCDE->abcde'
153
+ 2.4, the i-th expression of exprQs is the expression for applying the i-th factor of Q on a tensor, e.g.,
154
+ 'Cγ,abγde->abCde'
155
+ for i=2.
156
+
157
+ Please check https://drive.google.com/file/d/1CEEq7A3_l8EcPEDa_sYtqr5aMLVeZWL7/view?usp=drive_link for notations and derivations.
158
+ """
159
+ if dQ == "QUAD4P": # the only case that we fit P directly; so square Scale
160
+ Scale = Scale ** 2
161
+ shape = t.shape
162
+ if len(shape)==0: # scalar
163
+ Q = [Scale * torch.ones_like(t),]
164
+ L = [lift2single(torch.zeros_like(t.real)),]
165
+ exprA = opt_einsum.contract_expression(",->", Q[0].shape, t.shape)
166
+ exprP = opt_einsum.contract_expression(",,->", Q[0].shape, Q[0].shape, t.shape)
167
+ exprGs = [opt_einsum.contract_expression(",->", t.shape, t.shape),]
168
+ exprQs = [opt_einsum.contract_expression(",->", Q[0].shape, t.shape),]
169
+ else: # tensor
170
+ if len(shape) > 26:
171
+ raise ValueError(f"Got tensor with dim {len(t.shape)}; einsum runs out of letters; replace 26 with larger numbers.")
172
+
173
+ scale = Scale ** (1/len(shape))
174
+
175
+ Q, L = [], []
176
+ exprGs, exprQs = [], []
177
+ piece1A, piece2A, piece3A = [], "", "" # used for getting the subscripts for exprA
178
+ piece1P, piece2P, piece3P, piece4P = [], [], "", "" # used for getting the subscripts for exprP
179
+ for i, size in enumerate(shape):
180
+ L.append(lift2single(torch.zeros([], dtype=t.real.dtype, device=t.device)))
181
+ if size <= 1 or size > max_size or size**2 > max_skew * t.numel():
182
+ # use diagonal matrix as preconditioner for this dim
183
+ Q.append(scale * torch.ones(size, dtype=t.dtype, device=t.device))
184
+
185
+ piece1A.append(opt_einsum.get_symbol(i))
186
+ piece2A = piece2A + opt_einsum.get_symbol(i)
187
+ piece3A = piece3A + opt_einsum.get_symbol(i)
188
+
189
+ piece1P.append(opt_einsum.get_symbol(i + 26))
190
+ piece2P.append(opt_einsum.get_symbol(i + 26))
191
+ piece3P = piece3P + opt_einsum.get_symbol(i + 26)
192
+ piece4P = piece4P + opt_einsum.get_symbol(i + 26)
193
+
194
+ piece1 = "".join([opt_einsum.get_symbol(i+26) if j==i else opt_einsum.get_symbol(j) for j in range(len(shape))])
195
+ subscripts = piece1 + "," + piece1 + "->" + opt_einsum.get_symbol(i+26)
196
+ exprGs.append(opt_einsum.contract_expression(subscripts, t.shape, t.shape))
197
+
198
+ subscripts = opt_einsum.get_symbol(i+26) + "," + piece1 + "->" + piece1
199
+ exprQs.append(opt_einsum.contract_expression(subscripts, Q[-1].shape, t.shape))
200
+ else: # use matrix preconditioner for this dim
201
+ Q.append(scale * torch.eye(size, dtype=t.dtype, device=t.device))
202
+
203
+ piece1A.append(opt_einsum.get_symbol(i) + opt_einsum.get_symbol(i + 26))
204
+ piece2A = piece2A + opt_einsum.get_symbol(i + 26)
205
+ piece3A = piece3A + opt_einsum.get_symbol(i)
206
+
207
+ a, b, c = opt_einsum.get_symbol(i), opt_einsum.get_symbol(i + 26), opt_einsum.get_symbol(i + 805)
208
+ piece1P.append(a + b)
209
+ piece2P.append(a + c)
210
+ piece3P = piece3P + c
211
+ piece4P = piece4P + b
212
+
213
+ piece1 = "".join([opt_einsum.get_symbol(i+26) if j==i else opt_einsum.get_symbol(j) for j in range(len(shape))])
214
+ piece2 = "".join([opt_einsum.get_symbol(i+805) if j==i else opt_einsum.get_symbol(j) for j in range(len(shape))])
215
+ subscripts = piece1 + "," + piece2 + "->" + opt_einsum.get_symbol(i+26) + opt_einsum.get_symbol(i+805)
216
+ exprGs.append(opt_einsum.contract_expression(subscripts, t.shape, t.shape))
217
+
218
+ subscripts = opt_einsum.get_symbol(i+26) + opt_einsum.get_symbol(i+805) + "," + piece2 + "->" + piece1
219
+ exprQs.append(opt_einsum.contract_expression(subscripts, Q[-1].shape, t.shape))
220
+
221
+ subscripts = ",".join(piece1A) + "," + piece2A + "->" + piece3A
222
+ exprA = opt_einsum.contract_expression(subscripts, *[q.shape for q in Q], t.shape)
223
+
224
+ subscripts = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
225
+ exprP = opt_einsum.contract_expression(subscripts, *[q.shape for q in Q], *[q.shape for q in Q], t.shape)
226
+
227
+ exprGs, exprQs = tuple(exprGs), tuple(exprQs)
228
+ if dQ == "QEP":
229
+ return [[Q, L], (exprP, exprGs, exprQs)]
230
+ elif dQ == "EQ":
231
+ return [[Q, L], (exprP, exprGs, exprA)]
232
+ elif (dQ == "QEQ") or (dQ == "QUAD") or (dQ == "Q0p5EQ1p5") or (dQ == "Q0.5EQ1.5"):
233
+ return [[Q, L], (exprP, exprGs)]
234
+ else: # the only case that we fit P directly
235
+ assert dQ == "QUAD4P", "Invalid choice for dQ"
236
+ return [[Q, L], (exprA, exprGs)]
237
+
238
+
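To make the max_size/max_skew logic above concrete, here is a hypothetical probe (not part of the upstream file; it assumes opt_einsum is installed and init_kron is in scope). For a 10 x 50 gradient with max_size=20, the first factor is a dense 10 x 10 matrix and the second a length-50 diagonal:

import torch

t = torch.zeros(10, 50)
(Q, L), exprs = init_kron(t, Scale=1.0, max_size=20, max_skew=1.0, dQ="QEQ")
print([q.shape for q in Q])      # [torch.Size([10, 10]), torch.Size([50])]
print([x.shape for x in L])      # two scalar Lipschitz-constant estimates
exprP, exprGs = exprs            # for dQ="QEQ", exprs = (exprP, exprGs)
print(exprP(*[q.conj() for q in Q], *Q, t).shape)   # applying P keeps the shape: torch.Size([10, 50])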
239
+ def balance_kron_precond(Q):
240
+ """
241
+ In place balancing the dynamic ranges of the factors of Q to avoid over/under-flow.
242
+ """
243
+ order = len(Q) # order of tensor or the number of factors in Q
244
+ if order>1:
245
+ norms = [torch.max(torch.abs(q)) for q in Q]
246
+ gmean = torch.prod(torch.stack(norms))**(1/order) # geometric mean
247
+ for i, q in enumerate(Q):
248
+ q.mul_(gmean/norms[i])
249
+
250
+
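A small sanity check of the balancing step (hypothetical, not in the upstream file): the factors are rescaled toward a common magnitude, while their Kronecker product, and hence the preconditioner, is unchanged because the scale factors multiply to one.

import torch

torch.manual_seed(0)
Q = [100.0 * torch.randn(3, 3), 0.01 * torch.randn(4, 4)]
before = torch.kron(Q[0], Q[1])
balance_kron_precond(Q)
after = torch.kron(Q[0], Q[1])
print(torch.allclose(before, after, rtol=1e-5, atol=1e-6))   # True
print([float(q.abs().max()) for q in Q])                     # now of comparable magnitude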
251
+ def update_precond_kron_eq(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, balance_prob=0.01):
252
+ """
253
+ The raw function for updating the Kron preconditioner Q and Lipschitz smoothness constant L with pair (V, Hvp),
254
+ where Q is updated as dQ = E*Q,
255
+ the pair (V, Hvp) can be (vector, hess-vector-prod) or (randn, gradient/momentum).
256
+ The damping logic is not included here.
257
+ """
258
+ Q, L = QL
259
+ _, exprGs, exprA = exprs
260
+
261
+ def solve_triangular_right(B, A):
262
+ # return B @ inv(A)
263
+ if B.dim()>1:
264
+ return torch.linalg.solve_triangular(lift2single(A), lift2single(B), upper=True, left=False).to(B.dtype)
265
+ else: # torch.linalg.solve_triangular complains if B.dim() < 2. So insert None.
266
+ return (torch.linalg.solve_triangular(lift2single(A), lift2single(B[None,:]), upper=True, left=False)[0]).to(B.dtype)
267
+
268
+ A = exprA(*Q, Hvp)
269
+
270
+ order = V.dim()
271
+ p = list(range(order))
272
+ conjB = torch.permute(V.conj(), p[1:] + p[:1]) # permute dims like [0,1,2,3,4] -> [1,2,3,4,0]
273
+ for i, q in enumerate(Q):
274
+ conjB = conjB/q if q.dim()<2 else solve_triangular_right(conjB, q)
275
+ if i < order - 1: # transpose dims like [1,2,3,4,0]->[0,2,3,4,1]->[0,1,3,4,2]->[0,1,2,4,3]->[0,1,2,3,4]
276
+ conjB = torch.transpose(conjB, i, order - 1)
277
+
278
+ for i, q in enumerate(Q):
279
+ term1 = exprGs[i](A, A.conj())
280
+ term2 = exprGs[i](conjB.conj(), conjB)
281
+
282
+ if q.dim() < 2: # q is a diagonal matrix or scalar preconditioner
283
+ ell = torch.max(torch.real(term1 + term2))
284
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
285
+ q.sub_(lr/L[i] * (term1 - term2) * q) # q.mul_(1 - lr/L[i] * (term1 - term2)): larger roundoff errors
286
+ else: # q is a matrix preconditioner
287
+ ell = norm_lower_bound_spd(term1 + term2)
288
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
289
+ q.sub_(lr/L[i] * torch.triu(term1 - term2) @ q)
290
+
291
+ if torch.rand([]) < balance_prob: # balance factors of Q
292
+ balance_kron_precond(Q)
293
+
294
+
295
+ def precond_grad_kron(QL, exprs, G):
296
+ """
297
+ Precondition gradient G with Kron preconditioner Q.
298
+ """
299
+ Q, exprP = QL[0], exprs[0]
300
+ return exprP(*[q.conj() for q in Q], *Q, G)
301
+
302
+
303
+ def update_precond_kron_whiten_eq(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
304
+ """
305
+ Update the Kron preconditioner Q as dQ = E*Q.
306
+ """
307
+ V = torch.randn_like(G)
308
+ update_precond_kron_eq(QL, exprs, V, G + damping*V, lr=lr, betaL=betaL, balance_prob=balance_prob)
309
+
310
+
311
+ def update_precond_kron_whiten_qep(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=...):
312
+ """
313
+ Update the Kron preconditioner Q as dQ = Q*E*P.
314
+ """
315
+ Q, L = QL
316
+ exprP, exprGs, exprQs = exprs
317
+
318
+ # balancing is not optional as L for each factor is not scaling invariant
319
+ balance_kron_precond(Q)
320
+
321
+ total_numel = G.numel()
322
+ Pg = exprP(*[q.conj() for q in Q], *Q, G + damping*torch.randn_like(G))
323
+ for i, q in enumerate(Q):
324
+ QPg = exprQs[i](q, Pg)
325
+ term1 = exprGs[i](QPg, QPg.conj())
326
+ if q.dim() < 2: # diagonal or scalar Q
327
+ term2 = total_numel/q.numel() * q * q.conj()
328
+ ell = torch.max(torch.real(term1 + term2))
329
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
330
+ q.mul_(1 - lr/L[i] * (term1 - term2))
331
+ else: # matrix Q
332
+ term2 = total_numel/q.shape[0] * q @ q.H
333
+ ell = norm_lower_bound_spd(term1 + term2)
334
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
335
+ q.sub_(lr/L[i] * (term1 - term2) @ q)
336
+
337
+
338
+ def update_precond_kron_whiten_qeq(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
339
+ """
340
+ Update the Kron preconditioner Q as dQ = Q*E*Q.
341
+ """
342
+ Q, L = QL
343
+ exprP, exprGs = exprs
344
+
345
+ total_numel = G.numel()
346
+ Pg = exprP(*[q.conj() for q in Q], *Q, G + damping*torch.randn_like(G))
347
+ for i, q in enumerate(Q):
348
+ term1 = exprGs[i](Pg, Pg.conj())
349
+ if q.dim() < 2: # diagonal or scalar Q
350
+ term2 = total_numel/q.numel() # times I
351
+ ell = torch.max(torch.real(term1)) + term2
352
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
353
+ q.mul_(1 - lr/L[i] * (term1 - term2))
354
+ else: # matrix Q
355
+ term2 = total_numel/q.shape[0] # times I
356
+ ell = norm_lower_bound_spd(term1) + term2
357
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
358
+ q.sub_(lr/L[i] * (q @ term1 - q * term2))
359
+
360
+ if torch.rand([]) < balance_prob: # balance factors of Q
361
+ balance_kron_precond(Q)
362
+
363
+
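For a single scalar parameter the whitening fixed point is easy to see numerically (a hypothetical sketch, not in the upstream file): with a constant gradient g the update above drives P = Q^2 toward 1/|g|, so the preconditioned gradient approaches unit amplitude, as described in the KronWhiten docstring below.

import torch

g = torch.tensor(4.0)
QL, exprs = init_kron(g, Scale=1.0, dQ="QEQ")
for _ in range(2000):
    update_precond_kron_whiten_qeq(QL, exprs, g, lr=0.1, damping=1e-9)
print(float(precond_grad_kron(QL, exprs, g)))   # ~1.0, i.e. g / |g|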
364
+ def update_precond_kron_whiten_q0p5eq1p5(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
365
+ """
366
+ Update the Kron preconditioner Q as dQ = Q^0.5 * E * Q^1.5.
367
+ """
368
+ Q, L = QL
369
+ exprP, exprGs = exprs
370
+
371
+ total_numel = G.numel()
372
+ Pg = exprP(*[q.conj() for q in Q], *Q, G + damping*torch.randn_like(G))
373
+ for i, q in enumerate(Q):
374
+ term1 = exprGs[i](Pg, Pg.conj())
375
+ if q.dim() < 2: # diagonal or scalar Q
376
+ term2 = total_numel/q.numel() # times I
377
+ ell = torch.max(torch.real(term1)) + term2
378
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
379
+ q.mul_(1 - lr/L[i] * (term1 - term2))
380
+ else: # matrix Q
381
+ term2 = total_numel/q.shape[0] # times I
382
+ ell = norm_lower_bound_spd(term1) + term2
383
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
384
+ q.sub_(lr/L[i] * (term1 @ q - term2 * q))
385
+ procrustes_step(q)
386
+
387
+ if torch.rand([]) < balance_prob: # balance factors of Q
388
+ balance_kron_precond(Q)
389
+
390
+
391
+ def update_precond_kron_whiten_quad(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
392
+ """
393
+ Update the Kron preconditioner Q with a quadratic form.
394
+ """
395
+ Q, L = QL
396
+ exprP, exprGs = exprs
397
+
398
+ total_numel = G.numel()
399
+ Pg = exprP(*[q.conj() for q in Q], *Q, G + damping*torch.randn_like(G))
400
+ for i, q in enumerate(Q):
401
+ term1 = exprGs[i](Pg, Pg.conj())
402
+ if q.dim() < 2: # diagonal or scalar Q
403
+ term2 = total_numel/q.numel() # times I
404
+ ell = torch.max(torch.real(term1)) + term2
405
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
406
+ gain = 1 - lr/2/L[i] * (term1 - term2)
407
+ q.mul_(gain * gain)
408
+ else: # matrix Q
409
+ term2 = total_numel/q.shape[0] # times I
410
+ ell = norm_lower_bound_spd(term1) + term2
411
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
412
+ p = q - lr/2/L[i] * (term1 @ q - term2 * q)
413
+ p = p - lr/2/L[i] * (p @ term1 - p * term2)
414
+ q.copy_((p + p.H)/2) # p must be symmetric/hermitian
415
+
416
+ if torch.rand([]) < balance_prob: # balance factors of Q
417
+ balance_kron_precond(Q)
418
+
419
+
420
+ def update_precond_kron_whiten_quad4p(QL, exprs, G, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
421
+ """
422
+ Almost the same as function update_precond_kron_whiten_quad except that it fits P directly.
423
+ This is the only case that we fit P directly (Q here is P). Vulnerable to numerical errors.
424
+ """
425
+ Q, L = QL
426
+ exprA, exprGs = exprs
427
+
428
+ total_numel = G.numel()
429
+ Pg = exprA(*Q, G + damping*torch.randn_like(G)) # Q actually is P; so just applying all its factors once.
430
+ for i, q in enumerate(Q):
431
+ term1 = exprGs[i](Pg, Pg.conj())
432
+ if q.dim() < 2: # diagonal or scalar Q
433
+ term2 = total_numel/q.numel() # times I
434
+ ell = torch.max(torch.real(term1)) + term2
435
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
436
+ gain = 1 - lr/L[i] * (term1 - term2)
437
+ q.mul_(gain * gain)
438
+ else: # matrix Q
439
+ term2 = total_numel/q.shape[0] # times I
440
+ ell = norm_lower_bound_spd(term1) + term2
441
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
442
+ p = q - lr/L[i] * (term1 @ q - term2 * q)
443
+ p = p - lr/L[i] * (p @ term1 - p * term2)
444
+ q.copy_((p + p.H)/2) # p must be symmetric/hermitian
445
+
446
+ if torch.rand([]) < balance_prob: # balance factors of Q
447
+ balance_kron_precond(Q)
448
+
449
+
450
+ class KronWhiten:
451
+ """
452
+ Implements the PSGD optimizer with the Kronecker product gradient/momentum whitening preconditioner.
453
+ Most of the time, the hyperparameter name says it all. Here are some comments on a few key hyperparameters.
454
+
455
+ 1, preconditioner_max_size and preconditioner_max_skew. These two together control the complexity of the preconditioners.
456
+ For example, we are to precondition a 2D gradient with shape 10 x 50.
457
+ With preconditioner_max_size 20, we use a dense preconditioner for the first dim since 10 <= 20 and diagonal preconditioner for the second dim since 50 > 20.
458
+ With preconditioner_max_skew 1.5, we use a dense preconditioner for the first dim since 10/50 <= 1.5 and diagonal preconditioner for the second dim since 50/10 > 1.5.
459
+
460
+ 2, grad_clip_max_amp, betaL and damping. These three together help to stabilize the training.
461
+ PSGD here tries to normalize the gradients to unit amplitude. This can be problematic when gradients approach zeros.
462
+ The most effective way is to clip the preconditioned gradients if their amplitudes exceed grad_clip_max_amp, say 1.0.
463
+ Another way is to damp and upper bound the fitted preconditioner such that P < eye/damping.
464
+ For extremely sparse gradients, increasing betaL (say to 0.999) helps a lot, where betaL is the EMA factor for the L-smoothness constant (wrt Q) estimation.
465
+
466
+ 3, Lastly, dQ is for the selection of geometry for preconditioner update. QEQ, QUAD and Q0p5EQ1p5 all are good choices.
467
+ Q is initialized to preconditioner_init_scale * eye. The boolean setting whiten_grad decides whether to whiten the gradient or the momentum.
468
+ Always good to check https://arxiv.org/abs/2402.11858 for math details.
469
+ """
470
+ def __init__(self, params_with_grad,
471
+ preconditioner_max_size=float("inf"), preconditioner_max_skew=1.0, preconditioner_init_scale:float|None=None,
472
+ lr_params=0.001, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
473
+ grad_clip_max_amp=float("inf"), preconditioner_update_probability=1.0, whiten_grad=True, dQ="Q0.5EQ1.5"):
474
+ # mutable members
475
+ self.lr_params = lr_params
476
+ self.lr_preconditioner = lr_preconditioner
477
+ self.betaL = betaL # beta for the Lipschitz smoothness constant estimation; set to a large value for sparse gradients
478
+ self.damping = damping # to damp and upper bound the preconditioner such that P < eye/damping
479
+ self.momentum = momentum if (0<momentum<1) else 0.0
480
+ self.grad_clip_max_amp = grad_clip_max_amp # clip grad once its average amplitude exceeds this max amplitude setting
481
+ self.preconditioner_update_probability = preconditioner_update_probability
482
+ # protected members
483
+ self._preconditioner_max_size = preconditioner_max_size
484
+ self._preconditioner_max_skew = preconditioner_max_skew
485
+ params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
486
+ self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
487
+ self._num_params = sum([p.numel() for p in self._params_with_grad])
488
+ if preconditioner_init_scale is None:
489
+ self._QLs_exprs = None # initialize on the fly
490
+ print("FYI: Will set the preconditioner initial scale on the fly. It is recommended to set it manually.")
491
+ else:
492
+ self._QLs_exprs = [init_kron(p.squeeze(), preconditioner_init_scale, preconditioner_max_size, preconditioner_max_skew, dQ) for p in self._params_with_grad]
493
+ self._ms, self._counter_m = None, 0 # momentum buffers and counter
494
+ self._whiten_grad = whiten_grad # set to False to whiten momentum.
495
+ if not whiten_grad:
496
+ assert self.momentum > 0, "Cannot whiten momentum if the momentum setting is zero."
497
+ print(f"Recommend reducing lr_params by {int(((1 + momentum)/(1 - momentum))**0.5)} times")
498
+ self._dQ = dQ
499
+ if dQ == "QUAD4P": # the only case that we fit P directly
500
+ assert max([torch.finfo(p.dtype).eps for p in self._params_with_grad]) < 1e-6, "Directly fitting P needs at least single precision"
501
+ self._update_precond = update_precond_kron_whiten_quad4p
502
+ self._precond_grad = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)
503
+ else:
504
+ self._precond_grad = precond_grad_kron
505
+ if dQ == "QEP":
506
+ self._update_precond = update_precond_kron_whiten_qep
507
+ elif dQ == "EQ":
508
+ self._update_precond = update_precond_kron_whiten_eq
509
+ elif dQ == "QEQ":
510
+ self._update_precond = update_precond_kron_whiten_qeq
511
+ elif dQ == "QUAD":
512
+ self._update_precond = update_precond_kron_whiten_quad
513
+ else:
514
+ assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), "Invalid choice for dQ"
515
+ self._update_precond = update_precond_kron_whiten_q0p5eq1p5
516
+
517
+
518
+ @torch.no_grad()
519
+ def step(self, closure):
520
+ """
521
+ Performs one step of PSGD with the Kronecker product gradient/momentum whitening preconditioner.
522
+ """
523
+ with torch.enable_grad():
524
+ closure_returns = closure()
525
+ loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
526
+ grads = [g.squeeze() for g in torch.autograd.grad(loss, self._params_with_grad)]
527
+
528
+ if self._QLs_exprs is None:
529
+ scale = max([torch.mean((torch.abs(g))**4) for g in grads])
530
+ scale = (scale + self.damping**4)**(-1/8)
531
+ self._QLs_exprs = [init_kron(g, scale, self._preconditioner_max_size, self._preconditioner_max_skew, self._dQ) for g in grads]
532
+
533
+ if self.momentum > 0:
534
+ beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
535
+ self._counter_m += 1
536
+ if self._ms is None:
537
+ self._ms = [torch.zeros_like(g) for g in grads]
538
+
539
+ [m.mul_(beta).add_(g, alpha=1 - beta) for (m, g) in zip(self._ms, grads)]
540
+ else:
541
+ self._ms, self._counter_m = None, 0
542
+
543
+ if torch.rand([]) < self.preconditioner_update_probability: # update Q
544
+ if self._whiten_grad: # Q whitens gradient
545
+ [self._update_precond(*QL_exprs, g, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
546
+ for (QL_exprs, g) in zip(self._QLs_exprs, grads)]
547
+ else: # Q whitens momentum
548
+ [self._update_precond(*QL_exprs, m, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
549
+ for (QL_exprs, m) in zip(self._QLs_exprs, self._ms)]
550
+
551
+ if self.momentum > 0: # precondition momentum
552
+ pre_grads = [self._precond_grad(*QL_exprs, m) for (QL_exprs, m) in zip(self._QLs_exprs, self._ms)]
553
+ else: # precondition gradient
554
+ pre_grads = [self._precond_grad(*QL_exprs, g) for (QL_exprs, g) in zip(self._QLs_exprs, grads)]
555
+
556
+ lr = self.lr_params
557
+ if self.grad_clip_max_amp < float("inf"): # clip preconditioned gradient
558
+ avg_amp = torch.sqrt(torch.real(sum([torch.sum(g*g.conj()) for g in pre_grads]))/self._num_params)
559
+ if avg_amp > self.grad_clip_max_amp:
560
+ lr = lr * self.grad_clip_max_amp / avg_amp
561
+
562
+ # Update the parameters.
563
+ [param.subtract_(lr*g.view_as(param)) for (param, g) in zip(self._params_with_grad, pre_grads)]
564
+
565
+ # return whatever closure returns
566
+ return closure_returns
567
+
568
+
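A minimal usage sketch for the class above (hypothetical, not part of the upstream file; the tiny regression model, data and loop are made up, and KronWhiten is assumed importable from this module, torchzero/modules/adaptive/psgd/psgd.py). The optimizer computes gradients itself via torch.autograd.grad, so the closure only returns the loss and must not call backward():

import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)
opt = KronWhiten(model.parameters(), preconditioner_init_scale=1.0,
                 lr_params=1e-3, lr_preconditioner=0.1,
                 momentum=0.9, grad_clip_max_amp=1.0)

def closure():
    return torch.nn.functional.mse_loss(model(X), y)

for _ in range(100):
    loss = opt.step(closure)   # step returns whatever the closure returns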
569
+ def update_precond_kron_newton_eq(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
570
+ """
571
+ Update the Kron Newton-type preconditioner Q as dQ = E*Q with a pair of vector and hvp, (V, Hvp).
572
+ """
573
+ update_precond_kron_eq(QL, exprs, V, Hvp + damping*torch.randn_like(Hvp), lr=lr, betaL=betaL, balance_prob=balance_prob)
574
+
575
+
576
+ def update_precond_kron_newton_qep(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=...):
577
+ """
578
+ Update the Kron Newton-type preconditioner Q as dQ = Q*E*P with a pair of vector and hvp, (V, Hvp).
579
+ """
580
+ Q, L = QL
581
+ exprP, exprGs, exprQs = exprs
582
+
583
+ # balancing is not optional as L for each factor is not scaling invariant
584
+ balance_kron_precond(Q)
585
+ Ph = exprP(*[q.conj() for q in Q], *Q, Hvp + damping*torch.randn_like(Hvp))
586
+
587
+ for i, q in enumerate(Q):
588
+ QPh = exprQs[i](q, Ph)
589
+ Qv = exprQs[i](q, V)
590
+ term1 = exprGs[i](QPh, QPh.conj())
591
+ term2 = exprGs[i](Qv, Qv.conj())
592
+ if q.dim() < 2: # diagonal or scalar Q
593
+ ell = torch.max(torch.real(term1 + term2))
594
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
595
+ q.mul_(1 - lr/L[i] * (term1 - term2))
596
+ else: # matrix Q
597
+ ell = norm_lower_bound_spd(term1 + term2)
598
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
599
+ q.sub_(lr/L[i] * (term1 - term2) @ q)
600
+
601
+
602
+ def update_precond_kron_newton_qeq(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
603
+ """
604
+ Update the Kron Newton-type preconditioner Q as dQ = Q*E*Q with a pair of vector and hvp, (V, Hvp).
605
+ """
606
+ Q, L = QL
607
+ exprP, exprGs = exprs
608
+ Ph = exprP(*[q.conj() for q in Q], *Q, Hvp + damping*torch.randn_like(Hvp))
609
+
610
+ for i, q in enumerate(Q):
611
+ term1 = exprGs[i](Ph, Ph.conj())
612
+ term2 = exprGs[i](V, V.conj())
613
+ if q.dim() < 2: # diagonal or scalar Q
614
+ ell = torch.max(torch.real(term1 + term2))
615
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
616
+ q.mul_(1 - lr/L[i] * (term1 - term2))
617
+ else: # matrix Q
618
+ ell = norm_lower_bound_spd(term1 + term2)
619
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
620
+ q.sub_(lr/L[i] * q @ (term1 - term2))
621
+
622
+ if torch.rand([]) < balance_prob: # balance factors of Q
623
+ balance_kron_precond(Q)
624
+
625
+
626
+ def update_precond_kron_newton_q0p5eq1p5(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
627
+ """
628
+ Update the Kron Newton-type preconditioner Q as dQ = Q^0.5 * E * Q^1.5 with a pair of vector and hvp, (V, Hvp).
629
+ """
630
+ Q, L = QL
631
+ exprP, exprGs = exprs
632
+ Ph = exprP(*[q.conj() for q in Q], *Q, Hvp + damping*torch.randn_like(Hvp))
633
+
634
+ for i, q in enumerate(Q):
635
+ term1 = exprGs[i](Ph, Ph.conj())
636
+ term2 = exprGs[i](V, V.conj())
637
+ if q.dim() < 2: # diagonal or scalar Q
638
+ ell = torch.max(torch.real(term1 + term2))
639
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
640
+ q.mul_(1 - lr/L[i] * (term1 - term2))
641
+ else: # matrix Q
642
+ ell = norm_lower_bound_spd(term1 + term2)
643
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
644
+ q.sub_(lr/L[i] * (term1 - term2) @ q)
645
+ procrustes_step(q)
646
+
647
+ if torch.rand([]) < balance_prob: # balance factors of Q
648
+ balance_kron_precond(Q)
649
+
650
+
651
+ def update_precond_kron_newton_quad(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
652
+ """
653
+ Update the Kron Newton-type preconditioner Q with a quadratic form for dQ and pair of vector and hvp, (V, Hvp).
654
+ """
655
+ Q, L = QL
656
+ exprP, exprGs = exprs
657
+ Ph = exprP(*[q.conj() for q in Q], *Q, Hvp + damping*torch.randn_like(Hvp))
658
+
659
+ for i, q in enumerate(Q):
660
+ term1 = exprGs[i](Ph, Ph.conj())
661
+ term2 = exprGs[i](V, V.conj())
662
+ if q.dim() < 2: # diagonal or scalar Q
663
+ ell = torch.max(torch.real(term1 + term2))
664
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
665
+ gain = 1 - lr/2/L[i] * (term1 - term2)
666
+ q.mul_(gain * gain)
667
+ else: # matrix Q
668
+ ell = norm_lower_bound_spd(term1 + term2)
669
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
670
+ err = lr/2/L[i] * (term1 - term2)
671
+ p = q - err @ q # p = q - lr/L[i]/2 * (term1 - term2) @ q
672
+ p = p - p @ err # p = p - lr/L[i]/2 * p @ (term1 - term2)
673
+ q.copy_((p + p.H)/2) # p must be symmetric or hermitian
674
+
675
+ if torch.rand([]) < balance_prob: # balance factors of Q
676
+ balance_kron_precond(Q)
677
+
678
+
679
+ def update_precond_kron_newton_quad4p(QL, exprs, V, Hvp, lr=0.1, betaL=0.9, damping=1e-9, balance_prob=0.01):
680
+ """
681
+ Almost the same as function update_precond_kron_newton_quad except that we fit P directly.
682
+ This is the only case that fits P directly (Q here is P). It's vulnerable to numerical errors.
683
+ """
684
+ Q, L = QL
685
+ exprA, exprGs = exprs
686
+ Ph = exprA(*Q, Hvp + damping*torch.randn_like(Hvp)) # Q actually is P; so only need to apply its factors once.
687
+
688
+ for i, q in enumerate(Q):
689
+ term1 = exprGs[i](Ph, Ph.conj())
690
+ term2 = exprGs[i](V, V.conj())
691
+ if q.dim() < 2: # diagonal or scalar Q
692
+ ell = torch.max(torch.real(term1 + term2))
693
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
694
+ gain = 1 - lr/L[i] * (term1 - term2)
695
+ q.mul_(gain * gain)
696
+ else: # matrix Q
697
+ ell = norm_lower_bound_spd(term1 + term2)
698
+ L[i].copy_(torch.max(betaL*L[i] + (1 - betaL)*ell, ell))
699
+ err = lr/L[i] * (term1 - term2)
700
+ p = q - err @ q # p = q - lr/L[i] * (term1 - term2) @ q
701
+ p = p - p @ err # p = p - lr/L[i] * p @ (term1 - term2)
702
+ q.copy_((p + p.H)/2) # p must be symmetric or hermitian
703
+
704
+ if torch.rand([]) < balance_prob: # balance factors of Q
705
+ balance_kron_precond(Q)
706
+
707
+
708
+ class KronNewton:
709
+ """
710
+ Implements the Kronecker product Newton-type preconditioner as a class.
711
+ Most of the time, the hyperparameter name says it all. Here are some comments on a few key parameters.
712
+
713
+ 1, preconditioner_max_size and preconditioner_max_skew. These two together control the complexity of the preconditioners.
714
+ For example, we are to precondition a 2D gradient with shape 10 x 50.
715
+ With preconditioner_max_size 20, we use a dense preconditioner for the first dim since 10 <= 20 and diagonal preconditioner for the second dim since 50 > 20.
716
+ With preconditioner_max_skew 1.5, we use a dense preconditioner for the first dim since 10/50 <= 1.5 and diagonal preconditioner for the second dim since 50/10 > 1.5.
717
+
718
+ 2, grad_clip_max_norm, betaL and damping. These three together help to stabilize the training.
719
+ The grad_clip_max_norm is used to clip the preconditioned gradient to stabilize the optimization as in the classic trust region method.
720
+ Setting damping is used to damp and upper bound the fitted preconditioner such that P < eye/damping.
721
+ For extremely sparse Hess-vector-prod, a large betaL (say 0.999) helps a lot, where betaL is the EMA factor for the L-smoothness constant (wrt Q) estimation.
722
+
723
+ 3, exact_hessian_vector_product.
724
+ By setting this flag to False, the finite difference method will be used for Hvp approximation.
725
+ Be cautious with the finite difference method (possible numerical issues; the closure must behave like a pure function).
726
+
727
+ 4, Lastly, dQ is for the selection of geometry for preconditioner update. QEQ, QUAD and Q0p5EQ1p5 all are good choices.
728
+ Both lr_params and lr_preconditioner are normalized learning rates.
729
+ Q is initialized to preconditioner_init_scale * eye.
730
+ Always good to check https://arxiv.org/abs/2402.11858 for math details.
731
+ """
732
+ def __init__(self, params_with_grad, preconditioner_max_size=float("inf"), preconditioner_max_skew=1.0, preconditioner_init_scale:float|None=None,
733
+ lr_params=0.01, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
734
+ grad_clip_max_norm=float("inf"), preconditioner_update_probability=1.0,
735
+ exact_hessian_vector_product=True, dQ="Q0.5EQ1.5"):
736
+ # mutable members
737
+ self.lr_params = lr_params
738
+ self.lr_preconditioner = lr_preconditioner
739
+ self.betaL = betaL # beta for Lipschitz smoothness constant estimation; set to a large value for sparse Hvp
740
+ self.damping = damping # used to damp and upper bound P as P < eye/damping
741
+ self.momentum = momentum if (0<momentum<1) else 0.0
742
+ self.grad_clip_max_norm = grad_clip_max_norm
743
+ self.preconditioner_update_probability = preconditioner_update_probability
744
+ # protected members
745
+ self._preconditioner_max_size = preconditioner_max_size
746
+ self._preconditioner_max_skew = preconditioner_max_skew
747
+ params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
748
+ self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
749
+ eps = max([torch.finfo(p.dtype).eps for p in self._params_with_grad])
750
+ self._delta_param_scale = eps ** 0.5
751
+ if preconditioner_init_scale is None:
752
+ self._QLs_exprs = None # initialize on the fly
753
+ print("FYI: Will set the preconditioner initial scale on the fly. It is recommended to set it manually.")
754
+ else:
755
+ self._QLs_exprs = [init_kron(p.squeeze(), preconditioner_init_scale, preconditioner_max_size, preconditioner_max_skew, dQ) for p in self._params_with_grad]
756
+ self._ms, self._counter_m = None, 0 # momentum buffers and counter
757
+ self._exact_hessian_vector_product = exact_hessian_vector_product
758
+ if not exact_hessian_vector_product:
759
+ print("FYI: Approximate Hvp with finite-difference method. Make sure that: 1) the closure behaves like a pure function; 2) delta param scale is proper.")
760
+ self._dQ = dQ
761
+ if dQ == "QUAD4P": # the only case that fits P directly
762
+ self._update_precond = update_precond_kron_newton_quad4p
763
+ self._precond_grad = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)
764
+ assert eps < 1e-6, "Directly fitting P needs at least single precision"
765
+ else:
766
+ self._precond_grad = precond_grad_kron
767
+ if dQ == "QUAD":
768
+ self._update_precond = update_precond_kron_newton_quad
769
+ elif dQ == "QEP":
770
+ self._update_precond = update_precond_kron_newton_qep
771
+ elif dQ == "EQ":
772
+ self._update_precond = update_precond_kron_newton_eq
773
+ elif dQ == "QEQ":
774
+ self._update_precond = update_precond_kron_newton_qeq
775
+ else:
776
+ assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), "Invalid choice for dQ"
777
+ self._update_precond = update_precond_kron_newton_q0p5eq1p5
778
+
779
+
780
+ @torch.no_grad()
781
+ def step(self, closure):
782
+ """
783
+ Performs one step of PSGD with the Kronecker product Newton-type preconditioner.
784
+ """
785
+ if (torch.rand([]) < self.preconditioner_update_probability) or (self._QLs_exprs is None):
786
+ # evaluates gradients, Hessian-vector product, and updates the preconditioner
787
+ if self._exact_hessian_vector_product:
788
+ with torch.enable_grad():
789
+ closure_returns = closure()
790
+ loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
791
+ grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)
792
+ vs = [torch.randn_like(p) for p in self._params_with_grad]
793
+ Hvs = torch.autograd.grad(grads, self._params_with_grad, vs) # this line also works for complex matrices
794
+ else: # approximate the Hessian-vector product via finite-difference formulae. Use it with caution.
795
+ with torch.enable_grad():
796
+ closure_returns = closure()
797
+ loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
798
+ grads = torch.autograd.grad(loss, self._params_with_grad)
799
+
800
+ vs = [torch.randn_like(p) for p in self._params_with_grad]
801
+ [p.add_(v, alpha=self._delta_param_scale) for (p, v) in zip(self._params_with_grad, vs)]
802
+ with torch.enable_grad():
803
+ perturbed_returns = closure()
804
+ perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]
805
+ perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)
806
+ Hvs = [(perturbed_g - g)/self._delta_param_scale for (perturbed_g, g) in zip(perturbed_grads, grads)]
807
+ [p.sub_(v, alpha=self._delta_param_scale) for (p, v) in zip(self._params_with_grad, vs)] # remove the perturbation
808
+
809
+ if self._QLs_exprs is None: # initialize QLs on the fly if it is None
810
+ scale = (sum([torch.sum(torch.abs(v)**2) for v in vs])/sum([v.numel() for v in vs])) ** (1/4) # (mean(|v|^2))^(1/4)
811
+ scale = scale * (max([torch.mean((torch.abs(h))**4) for h in Hvs]) + self.damping**4) ** (-1/8) # (mean(|v|^2))^(1/4) * (mean(|h|^4))^(-1/8)
812
+ self._QLs_exprs = [init_kron(h.squeeze(), scale, self._preconditioner_max_size, self._preconditioner_max_skew, self._dQ) for h in Hvs]
813
+ # update preconditioner
814
+ [self._update_precond(*QL_exprs, v.squeeze(), h.squeeze(), lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
815
+ for (QL_exprs, v, h) in zip(self._QLs_exprs, vs, Hvs)]
816
+ else: # only evaluate the gradients
817
+ with torch.enable_grad():
818
+ closure_returns = closure()
819
+ loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
820
+ grads = torch.autograd.grad(loss, self._params_with_grad)
821
+
822
+ grads = [g.squeeze() for g in grads]
823
+ if self.momentum > 0: # precondition the momentum
824
+ beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
825
+ self._counter_m += 1
826
+ if self._ms is None:
827
+ self._ms = [torch.zeros_like(g) for g in grads]
828
+
829
+ [m.mul_(beta).add_(g, alpha=1 - beta) for (m, g) in zip(self._ms, grads)]
830
+ pre_grads = [self._precond_grad(*QL_exprs, m) for (QL_exprs, m) in zip(self._QLs_exprs, self._ms)]
831
+ else: # precondition the gradient
832
+ self._ms, self._counter_m = None, 0 # clear the buffer and counter when momentum is set to zero
833
+ pre_grads = [self._precond_grad(*QL_exprs, g) for (QL_exprs, g) in zip(self._QLs_exprs, grads)]
834
+
835
+ lr = self.lr_params
836
+ if self.grad_clip_max_norm < float("inf"):
837
+ grad_norm = torch.sqrt(torch.real(sum([torch.sum(g*g.conj()) for g in pre_grads])))
838
+ if grad_norm > self.grad_clip_max_norm:
839
+ lr = lr * self.grad_clip_max_norm / grad_norm
840
+
841
+ # Update the parameters.
842
+ [param.subtract_(lr*g.view_as(param)) for (param, g) in zip(self._params_with_grad, pre_grads)]
843
+
844
+ # return whatever closure returns
845
+ return closure_returns
846
+
847
+
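A matching usage sketch for KronNewton (hypothetical, not part of the upstream file; same made-up regression problem as in the KronWhiten example above). With exact_hessian_vector_product=True the closure must build a graph that supports double backward, so it should not call backward() or run under torch.no_grad():

import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)
opt = KronNewton(model.parameters(), preconditioner_init_scale=1.0,
                 lr_params=0.01, lr_preconditioner=0.1,
                 grad_clip_max_norm=10.0)

def closure():
    return torch.nn.functional.mse_loss(model(X), y)

for _ in range(100):
    loss = opt.step(closure)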
848
+ ############# End of PSGD Kronecker product preconditioners #############
849
+
850
+
851
+ ############# Begin of PSGD LRA (low rank approximation) preconditioners #############
852
+
853
+
854
+ def IpUVtmatvec(U, V, x):
855
+ """
856
+ Returns (I + U*V')*x. All variables are either matrices or column vectors.
857
+ """
858
+ return x + U.mm(V.t().mm(x))
859
+
860
+
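A quick equivalence check (hypothetical, not in the upstream file): the helper matches the explicit dense product (I + U V^T) x without ever forming the n x n matrix.

import torch

torch.manual_seed(0)
n, r = 50, 5
U, V = torch.randn(n, r), torch.randn(n, r)
x = torch.randn(n, 1)
dense = (torch.eye(n) + U @ V.T) @ x
print(torch.allclose(IpUVtmatvec(U, V, x), dense, rtol=1e-4, atol=1e-5))   # True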
861
+ def update_precond_lra(UVd, Luvd, v, h, lr=0.1, betaL=0.9):
862
+ """
863
+ The raw function for updating the LRA preconditioner Q = (I + U*V')*diag(d) with pair (v, h),
864
+ where h can a Hvp associated with v, or a gradient/momentum independent of v.
865
+ State variables (U, V, d) and their Lipschitz smoothness constant estimates (Lu, Lv, Ld) are updated in place.
866
+ Damping logic is not implemented here.
867
+ Note that U, V, d, v, and h all are either matrices or column vectors.
868
+ """
869
+ U, V, d = UVd
870
+ Lu, Lv, Ld = Luvd
871
+
872
+ # Approximately balancing U and V such that U^T U = V^T V (exact balancing needs three EVDs)
873
+ UtU, VtV = U.t() @ U, V.t() @ V
874
+ trUtU, trVtV = torch.sum(UtU.diagonal()), torch.sum(VtV.diagonal())
875
+ rho = (trUtU/trVtV) ** (1/4) # will scale U and V as U <-- U/rho and V <-- V*rho
876
+ rho2 = rho * rho
877
+ E = 0.1 * (UtU/rho2 - VtV*rho2)/(trUtU/rho2 + trVtV*rho2) # errors after scaling U and V
878
+ E2 = 0.5 * E @ E # using this E2 term to make (I - E + E^2/2)(I + E + E^2/2) = (I + E^2/2)^2 - E^2 = I + E^4/4
879
+ U.div_(rho), V.mul_(rho) # scale U and V to have ||U||_F = ||V||_F
880
+ U.sub_(U @ (E - E2)), V.add_(V @ (E + E2)) # rotate (as tr(E)=0) U and V to approach U^TU = V^TV
881
+
882
+ Qh = IpUVtmatvec(U, V, d * h)
883
+ Ph = d*IpUVtmatvec(V, U, Qh)
884
+
885
+ IpVtU = V.t().mm(U)
886
+ IpVtU.diagonal().add_(1) # avoid forming matrix I explicitly
887
+ invQtv = v/d
888
+ LU, pivots, _ = torch.linalg.lu_factor_ex(lift2single(IpVtU))
889
+ invQtv = invQtv - V.mm(torch.linalg.lu_solve(LU, pivots, lift2single(U.t().mm(invQtv)), adjoint=True).to(V.dtype))
890
+ invPv = invQtv - U.mm(torch.linalg.lu_solve(LU, pivots, lift2single(V.t().mm(invQtv))).to(U.dtype))
891
+ invPv = invPv/d
892
+
893
+ # update d
894
+ Phh, vinvPv = Ph*h, v*invPv
895
+ ell = torch.max(torch.abs(Phh)) + torch.max(torch.abs(vinvPv))
896
+ Ld.copy_(torch.max(betaL*Ld + (1 - betaL)*ell, ell))
897
+ d.sub_(lr/Ld*(Phh - vinvPv)*d) # d.mul_(1 - lr/Ld*(Phh - vinvPv)): larger roundoff errors, unstable with bfloat16 and lr<<1
898
+
899
+ a, b = Qh, invQtv
900
+ if torch.rand([]) < 0.5: # only update U
901
+ atV = a.t().mm(V)
902
+ btV = b.t().mm(V)
903
+ atVVt = atV.mm(V.t())
904
+ btVVt = btV.mm(V.t())
905
+ ell = (torch.linalg.vector_norm(a)*torch.linalg.vector_norm(atVVt) +
906
+ torch.linalg.vector_norm(b)*torch.linalg.vector_norm(btVVt))
907
+ Lu.copy_(torch.max(betaL*Lu + (1 - betaL)*ell, ell))
908
+ U.sub_(lr/Lu * ( a.mm(atV.mm(IpVtU)) - b.mm(btV.mm(IpVtU)) ))
909
+ else: # only update V
910
+ atU = a.t().mm(U)
911
+ btU = b.t().mm(U)
912
+ UUta = U.mm(atU.t())
913
+ UUtb = U.mm(btU.t())
914
+ ell = (torch.linalg.vector_norm(a)*torch.linalg.vector_norm(UUta) +
915
+ torch.linalg.vector_norm(b)*torch.linalg.vector_norm(UUtb))
916
+ Lv.copy_(torch.max(betaL*Lv + (1 - betaL)*ell, ell))
917
+ V.sub_(lr/Lv * ( (a + V.mm(atU.t())).mm(atU) - (b + V.mm(btU.t())).mm(btU) ))
918
+
919
+
920
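+ # A minimal setup sketch for calling update_precond_lra directly (the shapes and the Hvp
+ # helper below are illustrative assumptions, not part of this file): U and V are n x r,
+ # d is an n x 1 column vector, and the three Lipschitz estimates start as zero scalars.
+ #
+ #     n, r = 100, 10
+ #     UVd = [0.01*torch.randn(n, r), 0.01*torch.randn(n, r), torch.ones(n, 1)]
+ #     Luvd = [torch.zeros([]) for _ in range(3)]
+ #     v = torch.randn(n, 1)
+ #     h = my_hessian_vector_product(v)  # placeholder; could also be a gradient independent of v
+ #     update_precond_lra(UVd, Luvd, v, h, lr=0.1, betaL=0.9)
+
+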
+ def precond_grad_lra(UVd, g):
+     """
+     Precondition gradient g with Q = (I + U*V')*diag(d).
+     All variables here are either matrices or column vectors.
+     """
+     U, V, d = UVd
+     g = IpUVtmatvec(U, V, d * g)
+     g = d * IpUVtmatvec(V, U, g)
+     return g
+
+
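+ # Informal note: with Q = (I + U*V')*diag(d), the two matvecs above return P*g = Q'*Q*g
+ # without ever forming an n x n matrix. A small self-check sketch (illustrative shapes,
+ # not part of this file):
+ #
+ #     n, r = 6, 2
+ #     U, V, d = 0.1*torch.randn(n, r), 0.1*torch.randn(n, r), torch.rand(n, 1) + 0.5
+ #     Q = (torch.eye(n) + U @ V.t()) @ torch.diag(d.flatten())
+ #     g = torch.randn(n, 1)
+ #     assert torch.allclose(precond_grad_lra([U, V, d], g), Q.t() @ Q @ g, atol=1e-5)
+
+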
+ def update_precond_lra_whiten(UVd, Luvd, g, lr=0.1, betaL=0.9, damping=1e-9):
+     """
+     Update the LRA whitening preconditioner with gradient/momentum g.
+     """
+     v = torch.randn_like(g)
+     update_precond_lra(UVd, Luvd, v, g + damping*v, lr=lr, betaL=betaL)
+
+
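+ # Informal note (my reading of the construction above): pairing a fresh v ~ N(0, I) with
+ # h = g + damping*v drives the fitted preconditioner roughly toward
+ # P ~ (E[g g'] + damping^2 * I)^(-1/2), i.e. preconditioned gradients with about unit
+ # second moment, and it upper bounds P by eye/damping (see https://arxiv.org/abs/2402.11858
+ # for the exact criterion). A typical call sketch, with grad a column vector as produced in
+ # the classes below:
+ #
+ #     update_precond_lra_whiten(UVd, Luvd, grad, lr=0.1, betaL=0.9, damping=1e-9)
+ #     whitened_grad = precond_grad_lra(UVd, grad)
+
+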
+ class LRAWhiten:
+     """
+     Implements the PSGD LRA gradient/momentum whitening preconditioner as a class.
+     Most of the time, the hyperparameter name says it all. Here are some comments on a few key parameters.
+
+     1, rank_of_approximation.
+     Preconditioner Q has a diagonal part and a low-rank part, whose rank is decided by this setting.
+     Rank 0 reduces Q to a diagonal preconditioner.
+
+     2, grad_clip_max_amp, betaL and damping. These three together help to stabilize the training.
+     PSGD here tries to normalize the gradients to unit amplitude. This can be problematic when gradients approach zero.
+     The most effective way is to clip the preconditioned gradients when their amplitudes exceed grad_clip_max_amp, say 1.0.
+     Another way is to damp and upper bound the fitted preconditioner as P < eye/damping.
+     For extremely sparse gradients, increasing betaL (say to 0.999) also helps a lot, where betaL is the EMA factor for estimating the L-smoothness constant (wrt Q).
+
+     3, Lastly, Q is initialized to preconditioner_init_scale * eye.
+     The boolean setting whiten_grad decides whether to whiten the gradient or the momentum.
+     It is always good to check https://arxiv.org/abs/2402.11858 for the math details.
+     """
+     def __init__(self, params_with_grad, rank_of_approximation:int=10, preconditioner_init_scale:float|None=None,
+                  lr_params=0.001, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
+                  grad_clip_max_amp=float("inf"), preconditioner_update_probability=1.0, whiten_grad=True):
+         # mutable members
+         self.lr_params = lr_params
+         self.lr_preconditioner = lr_preconditioner
+         self.betaL = betaL # set to a large betaL for sparse gradients
+         self.damping = damping # to damp and upper bound P as P < eye/damping
+         self.momentum = momentum if (0<momentum<1) else 0.0
+         self.grad_clip_max_amp = grad_clip_max_amp
+         self.preconditioner_update_probability = preconditioner_update_probability
+         # protected members
+         params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
+         self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
+         dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device
+         self._param_sizes = [torch.numel(param) for param in self._params_with_grad]
+         self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)
+         num_params = self._param_cumsizes[-1]
+         assert 0 <= rank_of_approximation < num_params, "Rank r should be in range [0, number of total parameters)"
+         self._UVd = [] # saves U, V and d
+         self._UVd.append(torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device)) # U
+         self._UVd[0] *= 0.1**0.5 / torch.linalg.vector_norm(self._UVd[0])
+         self._UVd.append(torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device)) # V
+         self._UVd[1] *= 0.1**0.5 / torch.linalg.vector_norm(self._UVd[1])
+         if preconditioner_init_scale is None:
+             print("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
+         else:
+             self._UVd.append(torch.ones(num_params, 1, dtype=dtype, device=device) * preconditioner_init_scale)
+         self._Luvd = [lift2single(torch.zeros([], dtype=dtype, device=device)) for _ in range(3)]
+         self._m, self._counter_m = None, 0 # momentum buffer and counter
+         self._whiten_grad = whiten_grad
+         if (not whiten_grad):
+             assert self.momentum > 0, "Cannot whiten momentum if the momentum setting is zero."
+             print(f"Recommend to reduce lr_params by {int(((1 + momentum)/(1 - momentum))**0.5)} times")
+
+
+     @torch.no_grad()
+     def step(self, closure):
+         """
+         Performs one step of the PSGD LRA gradient/momentum whitening optimizer.
+         """
+         with torch.enable_grad():
+             closure_returns = closure()
+             loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
+             grads = torch.autograd.grad(loss, self._params_with_grad)
+
+         # cat grads
+         grad = torch.cat([torch.reshape(g, [-1, 1]) for g in grads]) # column vector
+
+         if len(self._UVd) < 3: # initialize d on the fly
+             self._UVd.append((torch.mean(grad**4) + self.damping**4)**(-1/8) * torch.ones_like(grad))
+
+         if self.momentum > 0:
+             beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
+             self._counter_m += 1
+             if self._m is None:
+                 self._m = torch.zeros_like(grad)
+
+             self._m.mul_(beta).add_(grad, alpha=1 - beta)
+         else: # clear the momentum buffer and counter when momentum is set to zero
+             self._m, self._counter_m = None, 0
+
+         if torch.rand([]) < self.preconditioner_update_probability: # update preconditioner
+             if self._whiten_grad: # whitens gradient
+                 update_precond_lra_whiten(self._UVd, self._Luvd, grad, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
+             else: # whitens momentum
+                 update_precond_lra_whiten(self._UVd, self._Luvd, self._m, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
+
+         if self.momentum > 0: # precondition momentum
+             pre_grad = precond_grad_lra(self._UVd, self._m)
+         else: # precondition gradient
+             pre_grad = precond_grad_lra(self._UVd, grad)
+
+         lr = self.lr_params
+         if self.grad_clip_max_amp < float("inf"): # clip preconditioned gradient
+             amp = torch.sqrt(torch.mean(pre_grad * pre_grad))
+             if amp > self.grad_clip_max_amp:
+                 lr = lr * self.grad_clip_max_amp/amp
+
+         # update the parameters
+         [param.subtract_(lr * pre_grad[j - i:j].view_as(param))
+          for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes)]
+
+         # return whatever closure returns
+         return closure_returns
+
+
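+ # A hypothetical usage sketch for LRAWhiten (model, loss_fn, loader, x, y are placeholder
+ # names, not part of this file); step() takes a closure and returns whatever the closure
+ # returns:
+ #
+ #     opt = LRAWhiten(model.parameters(), rank_of_approximation=10,
+ #                     preconditioner_init_scale=1.0, lr_params=1e-3, momentum=0.9)
+ #     for x, y in loader:
+ #         def closure():
+ #             return loss_fn(model(x), y)
+ #         loss = opt.step(closure)
+
+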
+ def update_precond_lra_newton(UVd, Luvd, v, h, lr=0.1, betaL=0.9, damping=1e-9):
+     """
+     Update the LRA Newton preconditioner.
+     """
+     update_precond_lra(UVd, Luvd, v, h + damping*torch.randn_like(h), lr=lr, betaL=betaL)
+
+
+ class LRANewton:
+     """
+     Implements the PSGD LRA Newton-type preconditioner as a class.
+     Most of the time, the hyperparameter name says it all. Here are some comments on a few key parameters.
+
+     1, rank_of_approximation.
+     Preconditioner Q has a diagonal part and a low-rank part, whose rank is decided by this setting.
+     Rank 0 reduces Q to a diagonal preconditioner.
+
+     2, grad_clip_max_norm, betaL and damping. These three together help to stabilize the training.
+     The grad_clip_max_norm is used to clip the preconditioned gradient to stabilize the optimization, as in classic trust-region methods.
+     The damping setting damps and upper bounds the preconditioner as P < eye/damping.
+     For extremely sparse Hessian-vector products, a large betaL (say 0.999) helps a lot, where betaL is the EMA factor for estimating the L-smoothness constant (wrt Q).
+
+     3, exact_hessian_vector_product.
+     Setting this flag to False switches to the finite-difference method for Hvp approximation.
+     Be cautious with the finite-difference method (possible numerical issues; the closure must behave like a pure function).
+
+     4, Lastly, Q is initialized to preconditioner_init_scale * eye.
+     Both lr_params and lr_preconditioner are normalized learning rates.
+     It is always good to check https://arxiv.org/abs/2402.11858 for the math details.
+     """
+     def __init__(self, params_with_grad, rank_of_approximation:int=10, preconditioner_init_scale:float|None=None,
+                  lr_params=0.01, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
+                  grad_clip_max_norm=float("inf"), preconditioner_update_probability=1.0,
+                  exact_hessian_vector_product=True):
+         # mutable members
+         self.lr_params = lr_params
+         self.lr_preconditioner = lr_preconditioner
+         self.betaL = betaL # set to a large betaL for sparse Hvp
+         self.damping = damping # to damp and upper bound the preconditioner as P < eye/damping
+         self.momentum = momentum if (0<momentum<1) else 0.0
+         self.grad_clip_max_norm = grad_clip_max_norm
+         self.preconditioner_update_probability = preconditioner_update_probability
+         # protected members
+         params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
+         self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
+         dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device
+         self._delta_param_scale = torch.finfo(dtype).eps**0.5
+         self._param_sizes = [torch.numel(param) for param in self._params_with_grad]
+         self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)
+         num_params = self._param_cumsizes[-1]
+         assert 0 <= rank_of_approximation < num_params, "Rank r should be in range [0, number of total parameters)"
+         self._UVd = [] # saves U, V and d
+         self._UVd.append(torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device)) # U
+         self._UVd[0] *= 0.1**0.5 / torch.linalg.vector_norm(self._UVd[0])
+         self._UVd.append(torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device)) # V
+         self._UVd[1] *= 0.1**0.5 / torch.linalg.vector_norm(self._UVd[1])
+         if preconditioner_init_scale is None:
+             print("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
+         else:
+             self._UVd.append(torch.ones(num_params, 1, dtype=dtype, device=device) * preconditioner_init_scale)
+         self._Luvd = [lift2single(torch.zeros([], dtype=dtype, device=device)) for _ in range(3)]
+         self._m, self._counter_m = None, 0 # momentum buffer and counter
+         self._exact_hessian_vector_product = exact_hessian_vector_product
+         if not exact_hessian_vector_product:
+             print("FYI: Approximate Hvp with finite-difference method. Make sure that: 1) the closure behaves like a pure function; 2) delta param scale is proper.")
+
+
+     @torch.no_grad()
+     def step(self, closure):
+         """
+         Performs one step of the PSGD LRA Newton optimizer.
+         """
+         if (torch.rand([]) < self.preconditioner_update_probability) or (len(self._UVd) < 3):
+             # evaluates gradients, Hessian-vector product, and updates the preconditioner
+             if self._exact_hessian_vector_product:
+                 with torch.enable_grad():
+                     closure_returns = closure()
+                     loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
+                     grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)
+                     vs = [torch.randn_like(param) for param in self._params_with_grad]
+                     Hvs = torch.autograd.grad(grads, self._params_with_grad, vs)
+             else: # approximate Hessian-vector product via finite-difference formulae. Use it with caution.
+                 with torch.enable_grad():
+                     closure_returns = closure()
+                     loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
+                     grads = torch.autograd.grad(loss, self._params_with_grad)
+
+                 vs = [torch.randn_like(param) for param in self._params_with_grad]
+                 [param.add_(v, alpha=self._delta_param_scale) for (param, v) in zip(self._params_with_grad, vs)]
+                 with torch.enable_grad():
+                     perturbed_returns = closure()
+                     perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]
+                     perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)
+                 Hvs = [(perturbed_g - g)/self._delta_param_scale for (perturbed_g, g) in zip(perturbed_grads, grads)]
+                 [param.sub_(v, alpha=self._delta_param_scale) for (param, v) in zip(self._params_with_grad, vs)]
+
+             v = torch.cat([torch.reshape(v, [-1, 1]) for v in vs]) # column vector
+             h = torch.cat([torch.reshape(h, [-1, 1]) for h in Hvs]) # column vector
+             if len(self._UVd) < 3: # init d if it's not in the UVd list
+                 self._UVd.append((torch.mean(v*v))**(1/4) * (torch.mean(h**4) + self.damping**4)**(-1/8) * torch.ones_like(v))
+
+             # update preconditioner
+             update_precond_lra_newton(self._UVd, self._Luvd, v, h, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
+         else: # only evaluates the gradients
+             with torch.enable_grad():
+                 closure_returns = closure()
+                 loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
+                 grads = torch.autograd.grad(loss, self._params_with_grad)
+
+         # cat grads
+         grad = torch.cat([torch.reshape(g, [-1, 1]) for g in grads]) # column vector
+
+         if self.momentum > 0: # precondition momentum
+             beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
+             self._counter_m += 1
+             if self._m is None:
+                 self._m = torch.zeros_like(grad)
+
+             self._m.mul_(beta).add_(grad, alpha=1 - beta)
+             pre_grad = precond_grad_lra(self._UVd, self._m)
+         else: # precondition gradient
+             self._m, self._counter_m = None, 0 # clear the buffer and counter when momentum is set to zero
+             pre_grad = precond_grad_lra(self._UVd, grad)
+
+         lr = self.lr_params
+         if self.grad_clip_max_norm < float("inf"):
+             grad_norm = torch.linalg.vector_norm(pre_grad)
+             if grad_norm > self.grad_clip_max_norm:
+                 lr = lr * self.grad_clip_max_norm / grad_norm
+
+         # update the parameters
+         [param.subtract_(lr * pre_grad[j - i:j].view_as(param))
+          for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes)]
+
+         # return whatever closure returns
+         return closure_returns
+
+
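+ # A hypothetical usage sketch for LRANewton (model, loss_fn, loader, x, y are placeholder
+ # names, not part of this file). The closure is re-differentiated inside step(), so it
+ # should simply rebuild and return the loss:
+ #
+ #     opt = LRANewton(model.parameters(), rank_of_approximation=10,
+ #                     preconditioner_init_scale=1.0, lr_params=0.01)
+ #     for x, y in loader:
+ #         loss = opt.step(lambda: loss_fn(model(x), y))
+
+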
+ ############# End of PSGD LRA preconditioners #############
+
+
+ ############# Begin of PSGD dense matrix Newton-type preconditioner #############
+
+
+ def update_precond_dense_eq(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
+     """
+     Update dense matrix Newton-type preconditioner Q with local coordinate dQ = mathcal{E} * Q.
+     """
+     a = Q.mm(h + damping*torch.randn_like(h))
+     b = torch.linalg.solve_triangular(lift2single(Q.t()), lift2single(v), upper=False).to(v.dtype)
+     ell = torch.sum(a*a + b*b)
+     L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
+     Q.sub_(lr/L * torch.triu(a.mm(a.t()) - b.mm(b.t())) @ Q)
+
+
+ def update_precond_dense_qep(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
+     """
+     Update dense matrix Newton-type preconditioner Q with local coordinate dQ = Q * mathcal{E} * P.
+     """
+     a = Q @ (Q.T @ (Q @ (h + damping*torch.randn_like(h))))
+     b = Q @ v
+     ell = torch.sum(a*a + b*b)
+     L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
+     Q.sub_(lr/L * (a @ (a.T @ Q) - b @ (b.T @ Q)))
+
+
+ def update_precond_dense_qeq(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
+     """
+     Update dense matrix Newton-type preconditioner Q with local coordinate dQ = Q * mathcal{E} * Q.
+     """
+     a = Q.T @ (Q @ (h + damping*torch.randn_like(h)))
+     ell = torch.sum(a*a + v*v)
+     L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
+     Q.sub_(lr/L * ((Q @ a) @ a.T - (Q @ v) @ v.T))
+
+
+ def update_precond_dense_q0p5eq1p5(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
+     """
+     Update dense matrix Newton-type preconditioner Q with local coordinate dQ = Q^0.5 * mathcal{E} * Q^1.5.
+     """
+     a = Q.T @ (Q @ (h + damping*torch.randn_like(h)))
+     ell = torch.sum(a*a + v*v)
+     L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
+     Q.sub_(lr/L * (a @ (a.T @ Q) - v @ (v.T @ Q)))
+     procrustes_step(Q)
+
+
+ def update_precond_dense_quad(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
+     """
+     Update dense matrix Newton-type preconditioner Q with a quadratic form for dQ.
+     """
+     a = Q @ (Q @ (h + damping*torch.randn_like(h))) # Q is symmetric here
+     ell = torch.sum(a*a + v*v)
+     L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
+     p = Q - lr/2/L * (a @ (a.T @ Q) - v @ (v.T @ Q))
+     p = p - lr/2/L * ((p @ a) @ a.T - (p @ v) @ v.T)
+     Q.copy_((p + p.T)/2)
+
+
+ def update_precond_dense_quad4p(Q, L, v, h, lr=0.1, betaL=0.9, damping=1e-9):
+     """
+     The only case that fits P directly.
+     """
+     a = Q @ (h + damping*torch.randn_like(h)) # Q actually is P; so just apply it once.
+     ell = torch.sum(a*a + v*v)
+     L.copy_(torch.max(betaL*L + (1 - betaL)*ell, ell))
+     p = Q - lr/L * (a @ (a.T @ Q) - v @ (v.T @ Q))
+     p = p - lr/L * ((p @ a) @ a.T - (p @ v) @ v.T)
+     Q.copy_((p + p.T)/2)
+
+
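+ # Informal note: for the dQ choices handled by the DenseNewton class below, the fitted
+ # preconditioner is applied as P*g = Q.T @ (Q @ g) ("EQ", "QEP", "QEQ", "Q0.5EQ1.5"),
+ # P*g = Q @ (Q @ g) with a symmetric Q ("QUAD"), or P*g = Q @ g ("QUAD4P", where Q stores
+ # P itself). A direct-call sketch (shapes and the Hvp helper are illustrative assumptions,
+ # not part of this file):
+ #
+ #     n = 50
+ #     Q = torch.eye(n)
+ #     L = torch.zeros([])
+ #     v = torch.randn(n, 1)
+ #     h = my_hessian_vector_product(v)  # placeholder for an Hvp of the target loss
+ #     update_precond_dense_qeq(Q, L, v, h, lr=0.1)
+
+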
+ class DenseNewton:
+     """
+     Implements the PSGD dense matrix Newton-type preconditioner as a class.
+     Be extra cautious when using the finite-difference method for Hvp approximation (the closure must behave like a pure function).
+     Due to its simplicity, this class is mainly useful for illustrating how PSGD works.
+     It is also a good alternative to BFGS-like quasi-Newton methods, as no line search is required.
+     """
+     def __init__(self, params_with_grad, preconditioner_init_scale:float|None=None,
+                  lr_params=0.01, lr_preconditioner=0.1, betaL=0.9, damping=1e-9, momentum=0.0,
+                  grad_clip_max_norm=float("inf"), preconditioner_update_probability=1.0,
+                  exact_hessian_vector_product=True, dQ="Q0.5EQ1.5"):
+         # mutable members
+         self.lr_params = lr_params
+         self.lr_preconditioner = lr_preconditioner
+         self.betaL = betaL # set to a large betaL for sparse Hvp
+         self.damping = damping # to damp and upper bound the preconditioner as P < eye/damping
+         self.momentum = momentum if (0<momentum<1) else 0.0
+         self.grad_clip_max_norm = grad_clip_max_norm
+         self.preconditioner_update_probability = preconditioner_update_probability
+         # protected members
+         params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
+         self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
+         dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device
+         self._delta_param_scale = torch.finfo(dtype).eps ** 0.5
+         self._param_sizes = [torch.numel(param) for param in self._params_with_grad]
+         self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)
+         num_params = self._param_cumsizes[-1]
+         if preconditioner_init_scale is None: # initialize Q on the fly
+             self._Q = None
+         else:
+             if dQ == "QUAD4P": # Q actually is P
+                 preconditioner_init_scale *= preconditioner_init_scale
+             self._Q = torch.eye(num_params, dtype=dtype, device=device) * preconditioner_init_scale
+         self._L = lift2single(torch.zeros([], dtype=dtype, device=device)) # Lipschitz smoothness constant estimation for the psgd criterion
+         self._m, self._counter_m = None, 0 # buffer and counter for momentum
+         self._exact_hessian_vector_product = exact_hessian_vector_product
+         if not exact_hessian_vector_product:
+             print("FYI: Approximate Hvp with finite-difference method. Make sure that: 1) the closure behaves like a pure function; 2) delta param scale is proper.")
+         self._dQ = dQ
+         if dQ == "QUAD4P": # the only case that we fit P directly
+             self._update_precond = update_precond_dense_quad4p
+             self._precond_grad = lambda Q, g: Q @ g
+             assert torch.finfo(dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
+         elif dQ == "QUAD":
+             self._update_precond = update_precond_dense_quad
+             self._precond_grad = lambda Q, g: Q @ (Q @ g) # Q is symmetric; just save one transpose
+         else:
+             self._precond_grad = lambda Q, g: Q.T @ (Q @ g)
+             if dQ == "QEP":
+                 self._update_precond = update_precond_dense_qep
+             elif dQ == "EQ":
+                 self._update_precond = update_precond_dense_eq
+             elif dQ == "QEQ":
+                 self._update_precond = update_precond_dense_qeq
+             else:
+                 assert (dQ == "Q0p5EQ1p5") or (dQ == "Q0.5EQ1.5"), "Invalid choice for dQ"
+                 self._update_precond = update_precond_dense_q0p5eq1p5
+
+
+     @torch.no_grad()
+     def step(self, closure):
+         """
+         Performs one step of PSGD with the dense matrix Newton-type preconditioner.
+         """
+         if (torch.rand([]) < self.preconditioner_update_probability) or (self._Q is None):
+             # evaluates gradients, Hessian-vector product, and updates the preconditioner
+             if self._exact_hessian_vector_product: # exact Hessian-vector product
+                 with torch.enable_grad():
+                     closure_returns = closure()
+                     loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
+                     grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)
+                     vs = [torch.randn_like(param) for param in self._params_with_grad]
+                     Hvs = torch.autograd.grad(grads, self._params_with_grad, vs)
+             else: # approximate Hessian-vector product via finite-difference formulae. Use it with caution.
+                 with torch.enable_grad():
+                     closure_returns = closure()
+                     loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
+                     grads = torch.autograd.grad(loss, self._params_with_grad)
+
+                 vs = [torch.randn_like(param) for param in self._params_with_grad]
+                 [param.add_(v, alpha=self._delta_param_scale) for (param, v) in zip(self._params_with_grad, vs)]
+                 with torch.enable_grad():
+                     perturbed_returns = closure()
+                     perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]
+                     perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)
+                 Hvs = [(perturbed_g - g)/self._delta_param_scale for (perturbed_g, g) in zip(perturbed_grads, grads)]
+                 [param.sub_(v, alpha=self._delta_param_scale) for (param, v) in zip(self._params_with_grad, vs)]
+
+             v = torch.cat([torch.reshape(v, [-1, 1]) for v in vs])
+             h = torch.cat([torch.reshape(h, [-1, 1]) for h in Hvs])
+             if self._Q is None: # initialize Q on the fly if it is None
+                 scale = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + self.damping**4)**(-1/8)
+                 if self._dQ == "QUAD4P": # Q actually is P in this case
+                     scale *= scale
+                 self._Q = torch.eye(len(v), dtype=v.dtype, device=v.device) * scale
+
+             # update preconditioner
+             self._update_precond(self._Q, self._L, v, h, lr=self.lr_preconditioner, betaL=self.betaL, damping=self.damping)
+         else: # only evaluates the gradients
+             with torch.enable_grad():
+                 closure_returns = closure()
+                 loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
+                 grads = torch.autograd.grad(loss, self._params_with_grad)
+
+         # cat grads
+         grad = torch.cat([torch.reshape(g, [-1, 1]) for g in grads])
+
+         if self.momentum > 0: # precondition momentum
+             beta = min(self._counter_m/(1 + self._counter_m), self.momentum)
+             self._counter_m += 1
+             if self._m is None:
+                 self._m = torch.zeros_like(grad)
+
+             self._m.mul_(beta).add_(grad, alpha=1 - beta)
+             pre_grad = self._precond_grad(self._Q, self._m)
+         else:
+             self._m, self._counter_m = None, 0 # clear the buffer and counter when momentum is set to zero
+             pre_grad = self._precond_grad(self._Q, grad)
+
+         lr = self.lr_params
+         if self.grad_clip_max_norm < float("inf"):
+             grad_norm = torch.linalg.vector_norm(pre_grad)
+             if grad_norm > self.grad_clip_max_norm:
+                 lr = lr * self.grad_clip_max_norm / grad_norm
+
+         # update the parameters
+         [param.subtract_(lr * pre_grad[j - i:j].view_as(param))
+          for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes)]
+
+         # return whatever closure returns
+         return closure_returns
+
+
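+ # A hypothetical usage sketch for DenseNewton (model, loss_fn, loader, x, y are placeholder
+ # names, not part of this file). The dense Q is num_params x num_params, so this is only
+ # practical for small models:
+ #
+ #     opt = DenseNewton(model.parameters(), preconditioner_init_scale=1.0,
+ #                       lr_params=0.01, dQ="Q0.5EQ1.5")
+ #     for x, y in loader:
+ #         loss = opt.step(lambda: loss_fn(model(x), y))
+
+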
+ ############# End of PSGD dense matrix Newton-type preconditioner #############
+
+ """ end of psgd """