torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/experimental/dsoap.py
@@ -0,0 +1,290 @@
+ from operator import itemgetter
+
+ import torch
+
+ from ...core import Chainable, Transform, apply
+ from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+
+ @torch.no_grad
+ def update_soap_covariances_(
+     grad: torch.Tensor,
+     GGs_: list[torch.Tensor | None],
+     beta: float | None,
+ ):
+     for i, GG in enumerate(GGs_):
+         if GG is None: continue
+
+         axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
+         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
+         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
+
+ @torch.no_grad
+ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+     """
+     Projects the gradient to the eigenbases of the preconditioner.
+     """
+     for mat in Q:
+         if mat is None: continue
+         if len(mat) > 0:
+             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
+         else:
+             # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+             permute_order = list(range(1, len(tensors.shape))) + [0]
+             tensors = tensors.permute(permute_order)
+
+     return tensors
+
+ @torch.no_grad
+ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+     """
+     Projects the gradient back to the original space.
+     """
+     for mat in Q:
+         if mat is None: continue
+         if len(mat) > 0:
+             tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
+         else:
+             permute_order = list(range(1, len(tensors.shape))) + [0]
+             tensors = tensors.permute(permute_order)
+
+     return tensors
+
+ # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+ @torch.no_grad
+ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
+     """
+     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
+     """
+     matrix = []
+     float_data = False
+     original_type = original_device = None
+     for m in mat:
+         if m is None: continue
+         if len(m) == 0:
+             matrix.append([])
+             continue
+         if m.dtype != torch.float:
+             original_type = m.dtype
+             original_device = m.device
+             matrix.append(m.float())
+         else:
+             float_data = True
+             matrix.append(m)
+
+     final = []
+     for m in matrix:
+         if len(m) == 0:
+             final.append([])
+             continue
+         try:
+             _, Q = torch.linalg.eigh(m + 1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+         except Exception:
+             _, Q = torch.linalg.eigh(m.to(torch.float64) + 1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+             Q = Q.to(m.dtype)
+         Q = torch.flip(Q, [1])
+
+         if not float_data:
+             Q = Q.to(original_device).type(original_type)
+         final.append(Q)
+     return final
+
+ # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
+ @torch.no_grad
+ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
+     """
+     Computes the eigenbases of the preconditioner using one round of power iteration
+     followed by torch.linalg.qr decomposition.
+     """
+     matrix = []
+     orth_matrix = []
+     float_data = False
+     original_type = original_device = None
+     for m, o in zip(GG, Q_list):
+         if m is None: continue
+         assert o is not None
+
+         if len(m) == 0:
+             matrix.append([])
+             orth_matrix.append([])
+             continue
+         if m.data.dtype != torch.float:
+             original_type = m.data.dtype
+             original_device = m.data.device
+             matrix.append(m.data.float())
+             orth_matrix.append(o.data.float())
+         else:
+             float_data = True
+             matrix.append(m.data.float())
+             orth_matrix.append(o.data.float())
+
+     final = []
+     for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
+         if len(m) == 0:
+             final.append([])
+             continue
+         est_eig = torch.diag(o.T @ m @ o)
+         sort_idx = torch.argsort(est_eig, descending=True)
+         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
+         o = o[:, sort_idx]
+         power_iter = m @ o
+         Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
+
+         if not float_data:
+             Q = Q.to(original_device).type(original_type)
+         final.append(Q)
+
+     return final, exp_avg_sq
+
+ class DSOAP(Transform):
+     """SOAP, but the preconditioner is built from scaled gradient differences.
+
+     New args (compared to SOAP):
+
+         scale_by_s: whether to scale gradient differences by the norm of the parameter differences.
+
+         y_to_ema2: whether to also use gradient differences for the second exponential moving average.
+     """
+     def __init__(
+         self,
+         beta1: float = 0.95,
+         beta2: float = 0.95,
+         shampoo_beta: float | None = 0.95,
+         precond_freq: int = 10,
+         merge_small: bool = True,
+         max_dim: int = 2_000,
+         precondition_1d: bool = True,
+         eps: float = 1e-8,
+         decay: float | None = None,
+         alpha: float = 1,
+         bias_correction: bool = True,
+         scale_by_s: bool = True,
+         y_to_ema2: bool = False,
+     ):
+         defaults = dict(
+             beta1=beta1,
+             beta2=beta2,
+             shampoo_beta=shampoo_beta,
+             precond_freq=precond_freq,
+             merge_small=merge_small,
+             max_dim=max_dim,
+             precondition_1d=precondition_1d,
+             eps=eps,
+             decay=decay,
+             bias_correction=bias_correction,
+             alpha=alpha,
+             scale_by_s=scale_by_s,
+             y_to_ema2=y_to_ema2,
+         )
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         updates = []
+         # update preconditioners
+         for i, (p, t) in enumerate(zip(params, tensors)):
+             state = self.state[p]
+             settings = self.settings[p]
+             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
+                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
+             scale_by_s = settings['scale_by_s']
+             y_to_ema2 = settings['y_to_ema2']
+
+             if merge_small:
+                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
+
+             if 'g_prev' not in state:
+                 state['p_prev'] = p.clone()
+                 state['g_prev'] = t.clone()
+                 updates.append(tensors[i].sign())
+                 continue
+
+             p_prev = state['p_prev']
+             g_prev = state['g_prev']
+             s = p - p_prev
+             y = t - g_prev
+             if scale_by_s: y /= torch.linalg.norm(s).clip(min=1e-8) # pylint:disable=not-callable
+
+             state['p_prev'].copy_(p)
+             state['g_prev'].copy_(t)
+
+             # initialize state on 1st step
+             if 'GG' not in state:
+                 state["exp_avg"] = torch.zeros_like(t)
+                 if y_to_ema2: state["exp_avg_sq"] = torch.ones_like(t)
+                 else: state["exp_avg_sq"] = torch.zeros_like(t)
+
+                 if not precondition_1d and t.ndim <= 1:
+                     state['GG'] = []
+
+                 else:
+                     state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1 < sh < max_dim else None for sh in t.shape]
+
+                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+                 if len([i for i in state['GG'] if i is not None]) == 0:
+                     state['GG'] = None
+
+                 if state['GG'] is not None:
+                     update_soap_covariances_(y, GGs_=state['GG'], beta=shampoo_beta)
+                     state['Q'] = get_orthogonal_matrix(state['GG'])
+
+                 state['step'] = 0
+                 updates.append(tensors[i].sign())
+                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+                 # sign is used instead so as not to interfere with subsequent modules; the 1st Adam step is a sign step anyway.
+
+             # Projecting gradients to the eigenbases of Shampoo's preconditioner,
+             # i.e. projecting to the eigenbases of matrices in state['GG']
+             z_projected = None
+             if state['GG'] is not None:
+                 if y_to_ema2: z_projected = project(y, state['Q'])
+                 else: z_projected = project(t, state['Q'])
+
+             # exponential moving averages
+             # this part could be foreached, but it is cheap compared to the preconditioning
+             exp_avg: torch.Tensor = state["exp_avg"]
+             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+
+             exp_avg.lerp_(t, 1-beta1)
+
+             if z_projected is None:
+                 if y_to_ema2: exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1-beta2)
+                 else: exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
+             else:
+                 exp_avg_sq.mul_(beta2).addcmul_(z_projected, z_projected, value=1-beta2)
+
+             # project the exponential moving average, which is accumulated unprojected
+             exp_avg_projected = exp_avg
+             if z_projected is not None:
+                 exp_avg_projected = project(exp_avg, state['Q'])
+
+             exp_avg_sq_projected = exp_avg_sq
+
+             denom = exp_avg_sq_projected.sqrt().add_(eps)
+             # print(f'{exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
+
+             # Projecting the (Adam-preconditioned) exponential moving average of gradients
+             # back to the original space
+             update = exp_avg_projected / denom
+             if z_projected is not None:
+                 update = project_back(update, state["Q"])
+
+             if settings['bias_correction']:
+                 bias_correction1 = 1.0 - beta1 ** (state["step"] + 1)
+                 bias_correction2 = 1.0 - beta2 ** (state["step"] + 1)
+                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
+             elif alpha is not None:
+                 update *= alpha
+
+             if merge_small:
+                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
+
+             updates.append(update)
+             state["step"] += 1
+
+             # The preconditioner update is done after the step to avoid using current gradients in the projection.
+             if state['GG'] is not None:
+                 update_soap_covariances_(y, state['GG'], shampoo_beta)
+                 if state['step'] % settings['precond_freq'] == 0:
+                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
+
+         return updates
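
The projection logic above mirrors the reference SOAP implementation: each factor in `state['GG']` accumulates a covariance along one gradient dimension, and `project`/`project_back` rotate tensors into and out of the eigenbases of those factors. The following standalone sketch (not part of the diff; shapes and values are made up for illustration) shows that round trip for a single 2-D gradient:

    import torch

    G = torch.randn(4, 3)                                    # toy 2-D gradient
    GG = [torch.zeros(4, 4), torch.zeros(3, 3)]              # one covariance factor per dimension
    for i in range(2):
        axes = [d for d in range(G.ndim) if d != i]          # contract every other dimension
        GG[i] += torch.tensordot(G, G, dims=(axes, axes))    # what update_soap_covariances_ does with beta=None

    Q = [torch.linalg.eigh(M)[1] for M in GG]                # eigenbases, as in get_orthogonal_matrix

    def project(t, Q):                                       # rotate each dimension into its eigenbasis
        for mat in Q:
            t = torch.tensordot(t, mat, dims=([0], [0]))
        return t

    def project_back(t, Q):                                  # inverse rotation: Q is orthogonal, so Q^T undoes Q
        for mat in Q:
            t = torch.tensordot(t, mat, dims=([0], [1]))
        return t

    assert torch.allclose(project_back(project(G, Q), Q), G, atol=1e-4)
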
torchzero/modules/experimental/gradmin.py
@@ -0,0 +1,85 @@
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Sequence
+ from typing import Literal
+
+ import torch
+
+ from ...core import Module, Vars
+ from ...utils import NumberList, TensorList
+ from ...utils.derivatives import jacobian_wrt
+ from ..grad_approximation import GradApproximator, GradTarget
+ from ..smoothing.gaussian import Reformulation
+
+
+
+ class GradMin(Reformulation):
+     """Reformulates the objective to minimize the sum of gradient magnitudes via autograd.
+
+     Args:
+         loss_term (float, optional): adds the loss value times this to the sum of gradient magnitudes. Defaults to 0.
+         relative (str, optional): makes one term relative to the other, 'loss_to_grad' or 'grad_to_loss'. Defaults to None.
+         graft (str, optional): grafts one term onto the magnitude of the other, 'loss_to_grad' or 'grad_to_loss'. Defaults to None.
+         square (bool, optional): whether to use the sum of squared gradient magnitudes; if False, uses absolute values. Defaults to False.
+         mean (bool, optional): whether to use the mean; if False, uses the sum. Defaults to True.
+         maximize_grad (bool, optional): whether to maximize gradient magnitudes instead of minimizing them. Defaults to False.
+         create_graph (bool, optional): whether to create the graph for the reformulated gradient. Defaults to False.
+         modify_loss (bool, optional): whether to replace the returned loss value so that line searches minimize the new objective. Defaults to True.
+     """
+     def __init__(
+         self,
+         loss_term: float | None = 0,
+         relative: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
+         graft: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
+         square=False,
+         mean=True,
+         maximize_grad=False,
+         create_graph=False,
+         modify_loss: bool = True,
+     ):
+         if (relative is not None) and (graft is not None): warnings.warn('both relative and graft are set, they will clash with each other')
+         defaults = dict(loss_term=loss_term, relative=relative, graft=graft, square=square, mean=mean, maximize_grad=maximize_grad, create_graph=create_graph, modify_loss=modify_loss)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def closure(self, backward, closure, params, vars):
+         settings = self.settings[params[0]]
+         loss_term = settings['loss_term']
+         relative = settings['relative']
+         graft = settings['graft']
+         square = settings['square']
+         maximize_grad = settings['maximize_grad']
+         create_graph = settings['create_graph']
+         modify_loss = settings['modify_loss']
+         mean = settings['mean']
+
+         with torch.enable_grad():
+             for p in params: p.grad = None
+             loss = closure(False)
+             grads = TensorList(torch.autograd.grad(loss, params, create_graph=True))
+
+             if square: grads = grads ** 2
+             else: grads = grads.abs()
+
+             if mean: f = grads.global_mean()
+             else: f = grads.global_sum()
+
+
+             if graft == 'grad_to_loss': f = f * (loss.detach()/f.detach()).detach()
+             if relative == 'grad_to_loss': f = f * loss
+
+             if loss_term is not None and loss_term != 0:
+                 if relative == 'loss_to_grad': loss_term = loss_term * f
+                 l = loss
+                 if graft == 'loss_to_grad': l = loss * (f.detach()/loss.detach()).detach()
+                 f = f + l*loss_term
+
+             if maximize_grad: f = -f
+             if modify_loss: loss = f
+
+             grad = None
+             if backward:
+                 for p in params: p.grad = None
+                 grad = TensorList(torch.autograd.grad(f, params, create_graph=create_graph))
+
+         return loss, grad
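
GradMin's closure builds a "gradient of the gradient" objective: the first gradient is computed with `create_graph=True` so that the mean (or sum) of its magnitudes is itself differentiable. A minimal standalone sketch of that double-backward pattern on a toy objective (not part of the diff; the objective and names are made up):

    import torch

    x = torch.randn(5, requires_grad=True)

    loss = (x ** 4).sum()                                        # toy objective
    (g,) = torch.autograd.grad(loss, [x], create_graph=True)     # dL/dx, kept in the graph
    f = g.abs().mean()                                           # reformulated objective: mean gradient magnitude
    (grad_f,) = torch.autograd.grad(f, [x])                      # gradient of f, i.e. a second backward pass

    # grad_f is what a module like GradMin hands to the rest of the chain;
    # minimizing f pushes dL/dx toward zero, i.e. toward a stationary point.
    print(f.item(), grad_f.shape)
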
torchzero/modules/experimental/reduce_outward_lr.py
@@ -0,0 +1,35 @@
+ import torch
+
+ from ...core import Target, Transform
+ from ...utils import TensorList
+
+ class ReduceOutwardLR(Transform):
+     """
+     When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
+
+     This means updates that move weights towards zero have higher learning rates.
+     """
+     def __init__(self, mul=0.5, use_grad=False, invert=False, target: Target = 'update'):
+         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
+         super().__init__(defaults, uses_grad=use_grad, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         params = TensorList(params)
+         tensors = TensorList(tensors)
+
+         mul = self.get_settings('mul', params=params)
+         s = self.settings[params[0]]
+         use_grad = s['use_grad']
+         invert = s['invert']
+
+         if use_grad: cur = vars.get_grad()
+         else: cur = tensors
+
+         # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
+         if invert: mask = (params * cur) > 0
+         else: mask = (params * cur) < 0
+
+         tensors.masked_set_(mask, tensors*mul)
+
+         return tensors
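
Judging by the inline comment ("minus ascent sign"), `tensors` is an ascent direction that a later module subtracts from the parameters, so `(params * cur) < 0` marks entries where the applied step would push the weight away from zero; those entries get the reduced learning rate. A standalone illustration on plain tensors (not part of the diff):

    import torch

    param  = torch.tensor([ 1.0, -2.0,  3.0, -4.0])
    update = torch.tensor([ 0.5,  0.5, -0.5, -0.5])      # applied later as param -= lr * update
    mul = 0.5

    outward = (param * update) < 0                       # True where |param - lr * update| would grow
    update = torch.where(outward, update * mul, update)

    print(outward)   # tensor([False,  True,  True, False])
    print(update)    # entries 1 and 2 were halved; inward-moving entries are untouched
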
torchzero/modules/experimental/spectral.py
@@ -0,0 +1,286 @@
+ from abc import ABC, abstractmethod
+ import math
+ from collections import deque
+ from typing import Literal, Any
+
+ import torch
+ from ...core import Chainable, TensorwisePreconditioner
+ from ...utils.linalg.matrix_funcs import matrix_power_eigh
+ from ...utils.linalg.svd import randomized_svd
+ from ...utils.linalg.qr import qr_householder
+
+
+ class _Solver:
+     @abstractmethod
+     def update(self, history: deque[torch.Tensor], damping: float | None) -> tuple[Any, Any]:
+         """returns the factors used by apply"""
+     @abstractmethod
+     def apply(self, __g: torch.Tensor, __A: torch.Tensor, __B: torch.Tensor) -> torch.Tensor:
+         """apply preconditioning to a tensor"""
+
+ class _SVDSolver(_Solver):
+     def __init__(self, driver=None): self.driver = driver
+     def update(self, history, damping):
+         M_hist = torch.stack(tuple(history), dim=1)
+         device = None # driver is CUDA only
+         if self.driver is not None:
+             device = M_hist.device
+             M_hist = M_hist.cuda()
+
+         try:
+             U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver=self.driver) # pylint:disable=not-callable
+
+             if self.driver is not None:
+                 U = U.to(device); S = S.to(device)
+
+             if damping is not None and damping != 0: S.add_(damping)
+             return U, S
+
+         except torch.linalg.LinAlgError:
+             return None, None
+
+     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
+         Utg = (U.T @ g).div_(S)
+         return U @ Utg
+
+ class _SVDLowRankSolver(_Solver):
+     def __init__(self, q: int = 6, niter: int = 2): self.q, self.niter = q, niter
+     def update(self, history, damping):
+         M_hist = torch.stack(tuple(history), dim=1)
+         try:
+             U, S, _ = torch.svd_lowrank(M_hist, q=self.q, niter=self.niter)
+             if damping is not None and damping != 0: S.add_(damping)
+             return U, S
+         except torch.linalg.LinAlgError:
+             return None, None
+
+     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
+         Utg = (U.T @ g).div_(S)
+         return U @ Utg
+
+ class _RandomizedSVDSolver(_Solver):
+     def __init__(self, k: int = 3, driver: str | None = 'gesvda'):
+         self.driver = driver
+         self.k = k
+
+     def update(self, history, damping):
+         M_hist = torch.stack(tuple(history), dim=1)
+         device = None # driver is CUDA only
+         if self.driver is not None:
+             device = M_hist.device
+             M_hist = M_hist.cuda()
+
+         try:
+             U, S, _ = randomized_svd(M_hist, k=self.k, driver=self.driver)
+
+             if self.driver is not None:
+                 U = U.to(device); S = S.to(device)
+
+             if damping is not None and damping != 0: S.add_(damping)
+             return U, S
+
+         except torch.linalg.LinAlgError:
+             return None, None
+
+     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
+         Utg = (U.T @ g).div_(S)
+         return U @ Utg
+
+ class _QRDiagonalSolver(_Solver):
+     def __init__(self, sqrt=True): self.sqrt = sqrt
+     def update(self, history, damping):
+         M_hist = torch.stack(tuple(history), dim=1)
+         try:
+             Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
+             R_diag = R.diag().abs()
+             if damping is not None and damping != 0: R_diag.add_(damping)
+             if self.sqrt: R_diag.sqrt_()
+             return Q, R_diag
+         except torch.linalg.LinAlgError:
+             return None, None
+
+     def apply(self, g: torch.Tensor, Q: torch.Tensor, R_diag: torch.Tensor):
+         Qtg = (Q.T @ g).div_(R_diag)
+         return Q @ Qtg
+
+ class _QRSolver(_Solver):
+     def __init__(self, sqrt=True): self.sqrt = sqrt
+     def update(self, history, damping):
+         M_hist = torch.stack(tuple(history), dim=1)
+         try:
+             # Q: d x k, R: k x k
+             Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
+             A = R @ R.T
+             if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
+             if self.sqrt: A = matrix_power_eigh(A, 0.5)
+             return Q, A
+         except torch.linalg.LinAlgError:
+             return None, None
+
+     def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
+         g_proj = Q.T @ g
+         y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
+         return Q @ y
+
+ class _QRHouseholderSolver(_Solver):
+     def __init__(self, sqrt=True): self.sqrt = sqrt
+     def update(self, history, damping):
+         M_hist = torch.stack(tuple(history), dim=1)
+         try:
+             # Q: d x k, R: k x k
+             Q, R = qr_householder(M_hist, mode='reduced') # pylint:disable=not-callable
+             A = R @ R.T
+             if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
+             if self.sqrt: A = matrix_power_eigh(A, 0.5)
+             return Q, A
+         except torch.linalg.LinAlgError:
+             return None, None
+
+     def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
+         g_proj = Q.T @ g
+         y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
+         return Q @ y
+
+
+ class _EighSolver(_Solver):
+     def __init__(self, sqrt=True):
+         self.sqrt = sqrt
+
+     def update(self, history, damping):
+         M_hist = torch.stack(tuple(history), dim=1)
+         grams = M_hist @ M_hist.T # (d, d)
+         if damping is not None and damping != 0: grams.diagonal(dim1=-2, dim2=-1).add_(damping)
+         try:
+             L, Q = torch.linalg.eigh(grams) # L: (d,), Q: (d, d) # pylint:disable=not-callable
+             L = L.abs().clamp_(min=1e-12)
+             if self.sqrt: L = L.sqrt()
+             return Q, L
+         except torch.linalg.LinAlgError:
+             return None, None
+
+     def apply(self, g: torch.Tensor, Q: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
+         Qtg = (Q.T @ g).div_(L)
+         return Q @ Qtg
+
+
+ SOLVERS = {
+     "svd": _SVDSolver(), # falls back to "gesvd", which takes ages or hangs completely
+     "svd_gesvdj": _SVDSolver("gesvdj"), # no fallback to the slow "gesvd"
+     "svd_gesvda": _SVDSolver("gesvda"), # approximate method for wide matrices, sometimes better sometimes worse, but faster
+     "svd_lowrank": _SVDLowRankSolver(), # may need parameter tuning; with the current ones it is slower and worse
+     "randomized_svd2": _RandomizedSVDSolver(2),
+     "randomized_svd3": _RandomizedSVDSolver(3),
+     "randomized_svd4": _RandomizedSVDSolver(4),
+     "randomized_svd5": _RandomizedSVDSolver(5),
+     "eigh": _EighSolver(), # this is O(n**2) storage, but is it more accurate?
+     "qr": _QRSolver(),
+     "qr_householder": _QRHouseholderSolver(), # slower, but may avoid freezing; svd_gesvda is probably better
+     "qrdiag": _QRDiagonalSolver(),
+ }
+
+ def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
+     if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
+     else:
+         if state_[key].shape != value.shape: state_[key] = value
+         else: state_[key].lerp_(value, 1-beta)
+
+ class SpectralPreconditioner(TensorwisePreconditioner):
+     """Whitening preconditioner via SVD on a history of past gradients, or of gradient differences scaled by parameter differences.
+
+     Args:
+         history_size (int, optional): number of past gradients to store for preconditioning. Defaults to 10.
+         update_freq (int, optional): how often to re-compute the preconditioner. Defaults to 1.
+         damping (float, optional): damping term, makes it closer to GD. Defaults to 1e-12.
+         order (int, optional):
+             whitening order: 1 approximates the FIM (maybe), 2 the hessian (maybe), 3+ is anyone's guess. Defaults to 1.
+         solver (str, optional): what to use for whitening. Defaults to 'svd_gesvda'.
+         A_beta (float | None, optional): beta for the projection factor A (probably a bad idea). Defaults to None.
+         B_beta (float | None, optional): beta for the scaling factor B (probably a bad idea). Defaults to None.
+         interval (int, optional): how often to update the history. Defaults to 1 (every step).
+         concat_params (bool, optional):
+             whether to apply preconditioning to each tensor separately (False, default) or to all tensors concatenated into a single vector (True). The latter is slower but captures interactions between layers.
+         scale_first (bool, optional): makes the first step small, usually not needed. Defaults to False.
+         inner (Chainable | None, optional): inner modules applied after updating the preconditioner and before applying it. Defaults to None.
+     """
+     def __init__(
+         self,
+         history_size: int = 10,
+         update_freq: int = 1,
+         damping: float = 1e-12,
+         order: int = 1,
+         solver: Literal['svd', 'svd_gesvdj', 'svd_gesvda', 'svd_lowrank', 'eigh', 'qr', 'qrdiag', 'qr_householder'] | _Solver | str = 'svd_gesvda',
+         A_beta: float | None = None,
+         B_beta: float | None = None,
+         interval: int = 1,
+         concat_params: bool = False,
+         scale_first: bool = False,
+         inner: Chainable | None = None,
+     ):
+         if isinstance(solver, str): solver = SOLVERS[solver]
+         # history is still updated each step, so the preconditioner's update_freq has a different meaning
+         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, order=order, A_beta=A_beta, B_beta=B_beta, solver=solver)
+         super().__init__(defaults, uses_grad=False, concat_params=concat_params, scale_first=scale_first, inner=inner, update_freq=interval)
+
+     @torch.no_grad
+     def update_tensor(self, tensor, param, grad, state, settings):
+         order = settings['order']
+         history_size = settings['history_size']
+         update_freq = settings['update_freq']
+         damping = settings['damping']
+         A_beta = settings['A_beta']
+         B_beta = settings['B_beta']
+         solver: _Solver = settings['solver']
+
+         if 'history' not in state: state['history'] = deque(maxlen=history_size)
+         history = state['history']
+
+         if order == 1: history.append(tensor.clone().view(-1))
+         else:
+
+             # if order=2 the history holds gradient differences, order=3 differences of differences, etc.,
+             # normalized by parameter differences
+             cur_p = param.clone()
+             cur_g = tensor.clone()
+             for i in range(1, order):
+                 if f'prev_g_{i}' not in state:
+                     state[f'prev_p_{i}'] = cur_p
+                     state[f'prev_g_{i}'] = cur_g
+                     break
+
+                 s_k = cur_p - state[f'prev_p_{i}']
+                 y_k = cur_g - state[f'prev_g_{i}']
+                 state[f'prev_p_{i}'] = cur_p
+                 state[f'prev_g_{i}'] = cur_g
+                 cur_p = s_k
+                 cur_g = y_k
+
+                 if i == order - 1:
+                     cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
+                     history.append(cur_g.view(-1))
+
+         step = state.get('step', 0)
+         if step % update_freq == 0 and len(history) != 0:
+             A, B = solver.update(history, damping=damping)
+             maybe_lerp_(state, A_beta, 'A', A)
+             maybe_lerp_(state, B_beta, 'B', B)
+
+         if len(history) != 0:
+             state['step'] = step + 1 # do not increment if there is no history yet (still gathering s_ks and y_ks)
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, state, settings):
+         history_size = settings['history_size']
+         solver: _Solver = settings['solver']
+
+         A = state.get('A', None)
+         if A is None:
+             # make a conservative step to avoid issues due to different GD scaling
+             return tensor.div_(max(1, tensor.abs().sum())) # pyright:ignore[reportArgumentType]
+
+         B = state['B']
+         update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
+
+         n = len(state['history'])
+         if n != history_size: update.mul_(n/history_size)
+         return update
+
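
The `_SVDSolver` path is representative of the other solvers: stack the stored history as columns of a matrix, factor it, and divide the gradient's components along the discovered directions by the corresponding singular values. A standalone sketch of that whitening step (not part of the diff; sizes are arbitrary):

    import torch

    d, k = 100, 10
    history = [torch.randn(d) for _ in range(k)]       # stand-in for the per-tensor gradient history
    g = torch.randn(d)

    M = torch.stack(history, dim=1)                    # (d, k), as in each solver's update()
    U, S, _ = torch.linalg.svd(M, full_matrices=False) # U: (d, k), S: (k,)
    S = S + 1e-12                                      # damping, matching the module's default

    whitened = U @ ((U.T @ g) / S)                     # apply(): project onto U, rescale, project back
    print(whitened.shape)                              # torch.Size([100])
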