torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/optimizers/soap.py
@@ -0,0 +1,286 @@
+ from operator import itemgetter
+
+ import torch
+
+ from ...core import Chainable, Transform, apply
+ from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+
+ @torch.no_grad
+ def update_soap_covariances_(
+     grad: torch.Tensor,
+     GGs_: list[torch.Tensor | None],
+     beta: float | None,
+ ):
+     for i, GG in enumerate(GGs_):
+         if GG is None: continue
+
+         axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
+         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
+         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
+
+ @torch.no_grad
+ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+     """
+     Projects the gradient to the eigenbases of the preconditioner.
+     """
+     for mat in Q:
+         if mat is None: continue
+         if len(mat) > 0:
+             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
+         else:
+             # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+             permute_order = list(range(1, len(tensors.shape))) + [0]
+             tensors = tensors.permute(permute_order)
+
+     return tensors
+
+ @torch.no_grad
+ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+     """
+     Projects the gradient back to the original space.
+     """
+     for mat in Q:
+         if mat is None: continue
+         if len(mat) > 0:
+             tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
+         else:
+             permute_order = list(range(1, len(tensors.shape))) + [0]
+             tensors = tensors.permute(permute_order)
+
+     return tensors
+
+ # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+ @torch.no_grad
+ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
+     """
+     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
+     """
+     matrix = []
+     float_data = False
+     original_type = original_device = None
+     for m in mat:
+         if m is None: continue
+         if len(m) == 0:
+             matrix.append([])
+             continue
+         if m.dtype != torch.float:
+             original_type = m.dtype
+             original_device = m.device
+             matrix.append(m.float())
+         else:
+             float_data = True
+             matrix.append(m)
+
+     final = []
+     for m in matrix:
+         if len(m) == 0:
+             final.append([])
+             continue
+         try:
+             _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+         except Exception:
+             _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+             Q = Q.to(m.dtype)
+         Q = torch.flip(Q, [1])
+
+         if not float_data:
+             Q = Q.to(original_device).type(original_type)
+         final.append(Q)
+     return final
+
+ # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
+ @torch.no_grad
+ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
+     """
+     Computes the eigenbases of the preconditioner using one round of power iteration
+     followed by torch.linalg.qr decomposition.
+     """
+     matrix = []
+     orth_matrix = []
+     float_data = False
+     original_type = original_device = None
+     for m, o in zip(GG, Q_list):
+         if m is None: continue
+         assert o is not None
+
+         if len(m) == 0:
+             matrix.append([])
+             orth_matrix.append([])
+             continue
+         if m.data.dtype != torch.float:
+             original_type = m.data.dtype
+             original_device = m.data.device
+             matrix.append(m.data.float())
+             orth_matrix.append(o.data.float())
+         else:
+             float_data = True
+             matrix.append(m.data.float())
+             orth_matrix.append(o.data.float())
+
+     final = []
+     for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
+         if len(m) == 0:
+             final.append([])
+             continue
+         est_eig = torch.diag(o.T @ m @ o)
+         sort_idx = torch.argsort(est_eig, descending=True)
+         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
+         o = o[:, sort_idx]
+         power_iter = m @ o
+         Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
+
+         if not float_data:
+             Q = Q.to(original_device).type(original_type)
+         final.append(Q)
+
+     return final, exp_avg_sq
+
+ class SOAP(Transform):
+     """SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).
+
+     Args:
+         beta1 (float, optional): beta for first momentum. Defaults to 0.95.
+         beta2 (float, optional): beta for second momentum. Defaults to 0.95.
+         shampoo_beta (float | None, optional):
+             beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.
+         precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
+         merge_small (bool, optional): Whether to merge small dims. Defaults to True.
+         max_dim (int, optional): Won't precondition dims larger than this. Defaults to 2_000.
+         precondition_1d (bool, optional):
+             Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
+         eps (float, optional):
+             epsilon for dividing first momentum by second. Defaults to 1e-8.
+         decay (float | None, optional):
+             Decays covariance matrix accumulators, this may be useful if `shampoo_beta` is None. Defaults to None.
+         unprojected_exp_avg (bool, optional):
+             whether to update first momentum in unprojected space. Both true and false work and lead to different
+             results but True usually works better. Defaults to True.
+         bias_correction (bool, optional):
+             enables adam bias correction. Defaults to True.
+     """
+     def __init__(
+         self,
+         beta1: float = 0.95,
+         beta2: float = 0.95,
+         shampoo_beta: float | None = 0.95,
+         precond_freq: int = 10,
+         merge_small: bool = True,
+         max_dim: int = 2_000,
+         precondition_1d: bool = True,
+         eps: float = 1e-8,
+         decay: float | None = None,
+         alpha: float = 1,
+         unprojected_exp_avg: bool = True,
+         bias_correction: bool = True,
+     ):
+         defaults = dict(
+             beta1=beta1,
+             beta2=beta2,
+             shampoo_beta=shampoo_beta,
+             precond_freq=precond_freq,
+             merge_small=merge_small,
+             max_dim=max_dim,
+             precondition_1d=precondition_1d,
+             eps=eps,
+             decay=decay,
+             unprojected_exp_avg=unprojected_exp_avg,
+             bias_correction=bias_correction,
+             alpha=alpha,
+         )
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         updates = []
+         # update preconditioners
+         for i, (p, t) in enumerate(zip(params, tensors)):
+             state = self.state[p]
+             settings = self.settings[p]
+             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg, alpha = itemgetter(
+                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg', 'alpha')(settings)
+
+             if merge_small:
+                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
+
+             # initialize state on 1st step
+             if 'GG' not in state:
+                 state["exp_avg"] = torch.zeros_like(t)
+                 state["exp_avg_sq"] = torch.zeros_like(t)
+
+                 if not precondition_1d and t.ndim <= 1:
+                     state['GG'] = []
+
+                 else:
+                     state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
+
+                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+                 if len([i is not None for i in state['GG']]) == 0:
+                     state['GG'] = None
+
+                 if state['GG'] is not None:
+                     update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
+                     state['Q'] = get_orthogonal_matrix(state['GG'])
+
+                 state['step'] = 0
+                 updates.append(tensors[i].sign().div_(10))
+                 # updates.append(tensors[i] / tensors[i].abs().sum())
+                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
+                 # I use scaled update instead as to not mess up with next modules.
+
+             # Projecting gradients to the eigenbases of Shampoo's preconditioner
+             # i.e. projecting to the eigenbases of matrices in state['GG']
+             t_projected = None
+             if state['GG'] is not None:
+                 t_projected = project(t, state['Q'])
+
+             # exponential moving averages
+             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
+             exp_avg: torch.Tensor = state["exp_avg"]
+             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+
+             if unprojected_exp_avg or t_projected is None:
+                 exp_avg.lerp_(t, 1-beta1)
+             else:
+                 exp_avg.lerp_(t_projected, 1-beta1)
+
+             if t_projected is None:
+                 exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
+             else:
+                 exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+
+             # project exponential moving averages if they are accumulated unprojected
+             exp_avg_projected = exp_avg
+             if unprojected_exp_avg and t_projected is not None:
+                 exp_avg_projected = project(exp_avg, state['Q'])
+
+             exp_avg_sq_projected = exp_avg_sq
+
+             denom = exp_avg_sq_projected.sqrt().add_(eps)
+             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
+
+             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+             # to the original space
+             update = exp_avg_projected / denom
+             if t_projected is not None:
+                 update = project_back(update, state["Q"])
+
+             if settings['bias_correction']:
+                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
+                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
+                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
+             elif alpha is not None:
+                 update *= alpha
+
+             if merge_small:
+                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
+
+             updates.append(update)
+             state["step"] += 1
+
+             # Update is done after the gradient step to avoid using current gradients in the projection.
+             if state['GG'] is not None:
+                 update_soap_covariances_(t, state['GG'], shampoo_beta)
+                 if state['step'] % settings['precond_freq'] == 0:
+                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
+
+         return updates
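
For orientation, the sketch below (not part of the package) exercises the projection helpers added in this file, assuming they can be imported from torchzero.modules.optimizers.soap: it accumulates one covariance factor per dimension of a toy gradient, computes the eigenbases, and checks that project followed by project_back returns the original tensor, which is what lets SOAP run the Adam step in the rotated space and map it back.

# Illustrative sketch only; the import path is assumed from the file list above.
import torch
from torchzero.modules.optimizers.soap import (
    update_soap_covariances_, get_orthogonal_matrix, project, project_back,
)

g = torch.randn(4, 3)                            # toy gradient
GG = [torch.zeros(s, s) for s in g.shape]        # one covariance accumulator per dimension, as in SOAP.transform
update_soap_covariances_(g, GGs_=GG, beta=0.95)  # GG[i] <- lerp(GG[i], contraction of g with itself, 0.05)

Q = get_orthogonal_matrix(GG)    # eigenbasis of each covariance factor
g_rot = project(g, Q)            # rotate the gradient into the eigenbasis
g_back = project_back(g_rot, Q)  # rotate back to the original space

print(torch.allclose(g, g_back, atol=1e-5))  # True: the per-dimension rotations are orthogonal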
torchzero/modules/optimizers/sophia_h.py
@@ -0,0 +1,129 @@
+ from typing import Literal
+ from collections.abc import Callable
+ import torch
+
+ from ...core import Module, Target, Transform, Chainable, apply
+ from ...utils import NumberList, TensorList, as_tensorlist
+ from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+
+ def sophia_H(
+     tensors: TensorList,
+     h: TensorList | None,
+     exp_avg_: TensorList,
+     h_exp_avg_: TensorList,
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     update_freq: int,
+     precond_scale: float | NumberList,
+     clip: float | NumberList,
+     eps: float | NumberList,
+     step: int
+ ):
+     # momentum
+     exp_avg_.lerp_(tensors, 1-beta1)
+
+     # update preconditioner
+     if step % update_freq == 0:
+         assert h is not None
+         h_exp_avg_.lerp_(h, 1-beta2)
+
+     else:
+         assert h is None
+
+     denom = (h_exp_avg_ * precond_scale).clip_(min=eps)
+     return (exp_avg_ / denom).clip_(-clip, clip)
+
+
+ class SophiaH(Module):
+     def __init__(
+         self,
+         beta1: float = 0.96,
+         beta2: float = 0.99,
+         update_freq: int = 10,
+         precond_scale: float = 1,
+         clip: float = 1,
+         eps: float = 1e-12,
+         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+         fd_h: float = 1e-3,
+         n_samples = 1,
+         seed: int | None = None,
+         inner: Chainable | None = None
+     ):
+         defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, precond_scale=precond_scale, clip=clip, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = vars.params
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         fd_h = settings['fd_h']
+         update_freq = settings['update_freq']
+         n_samples = settings['n_samples']
+
+         seed = settings['seed']
+         generator = None
+         if seed is not None:
+             if 'generator' not in self.global_state:
+                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             generator = self.global_state['generator']
+
+         beta1, beta2, precond_scale, clip, eps = self.get_settings(
+             'beta1', 'beta2', 'precond_scale', 'clip', 'eps', params=params, cls=NumberList)
+
+         exp_avg, h_exp_avg = self.get_state('exp_avg', 'h_exp_avg', params=params, cls=TensorList)
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         closure = vars.closure
+         assert closure is not None
+
+         h = None
+         if step % update_freq == 0:
+
+             grad = None
+             for i in range(n_samples):
+                 u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]
+
+                 if hvp_method == 'autograd':
+                     if grad is None: grad = vars.get_grad(create_graph=True)
+                     assert grad is not None
+                     Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)
+
+                 elif hvp_method == 'forward':
+                     loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=vars.get_grad(), normalize=True)
+
+                 elif hvp_method == 'central':
+                     loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
+
+                 else:
+                     raise ValueError(hvp_method)
+
+                 if h is None: h = Hvp
+                 else: torch._foreach_add_(h, Hvp)
+
+             assert h is not None
+             if n_samples > 1: torch._foreach_div_(h, n_samples)
+
+         update = vars.get_update()
+         if 'inner' in self.children:
+             update = apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars)
+
+         vars.update = sophia_H(
+             tensors=TensorList(update),
+             h=TensorList(h) if h is not None else None,
+             exp_avg_=exp_avg,
+             h_exp_avg_=h_exp_avg,
+             beta1=beta1,
+             beta2=beta2,
+             update_freq=update_freq,
+             precond_scale=precond_scale,
+             clip=clip,
+             eps=eps,
+             step=step,
+         )
+         return vars
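
The functional sophia_H above applies the Sophia-H rule over TensorLists: keep an EMA of the incoming update, keep an EMA of a Hessian-vector-product estimate that is refreshed every update_freq steps, then divide and clip. A rough single-tensor sketch of that rule (plain torch tensors and a made-up helper name, not the module's actual API):

import torch

def sophia_h_step(g, hvp_estimate, exp_avg, h_exp_avg, beta1=0.96, beta2=0.99,
                  precond_scale=1.0, clip=1.0, eps=1e-12, refresh=True):
    # EMA of the incoming update (momentum)
    exp_avg.lerp_(g, 1 - beta1)
    # EMA of the Hessian-vector estimate, refreshed only on preconditioner-update steps
    if refresh:
        h_exp_avg.lerp_(hvp_estimate, 1 - beta2)
    denom = (h_exp_avg * precond_scale).clamp_min(eps)
    return (exp_avg / denom).clamp_(-clip, clip)

g, hvp_estimate = torch.randn(5), torch.rand(5)
exp_avg, h_exp_avg = torch.zeros(5), torch.zeros(5)
print(sophia_h_step(g, hvp_estimate, exp_avg, h_exp_avg))  # elementwise update, clipped to [-1, 1]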
torchzero/modules/projections/__init__.py
@@ -0,0 +1,5 @@
+ from .projection import Projection
+ from .fft import FFTProjection
+ from .structural import VectorProjection, TensorizeProjection, BlockPartition, TensorNormsProjection
+
+ # from .galore import GaLore
torchzero/modules/projections/dct.py
@@ -0,0 +1,73 @@
+ from typing import Literal
+ import torch
+ import torch_dct
+ from .projection import Projection
+ from ...core import Chainable
+
+ def reverse_dims(t: torch.Tensor):
+     return t.permute(*reversed(range(t.ndim)))
+
+ class DCTProjection(Projection):
+     # norm description copied from pytorch docstring
+     """Project update into Discrete Cosine Transform space, requires `torch_dct` library.
+
+     Args:
+         modules (Chainable): modules that will optimize the projected update.
+         dims (1, 2 or 3, optional):
+             applies DCT to first 1, 2 or 3 dims, defaults to 3.
+         norm (str, optional):
+             Normalization mode.
+             * None - no normalization
+             * "ortho" - normalize by 1/sqrt(n)
+     """
+
+     def __init__(
+         self,
+         modules: Chainable,
+         dims: Literal[1, 2, 3] = 3,
+         norm=None,
+         project_update=True,
+         project_params=False,
+         project_grad=False,
+     ):
+         defaults = dict(dims=dims, norm=norm)
+         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         settings = self.settings[vars.params[0]]
+         dims = settings['dims']
+         norm = settings['norm']
+
+         projected = []
+         for u in tensors:
+             u = reverse_dims(u)
+             dim = min(u.ndim, dims)
+
+             if dim == 1: dct = torch_dct.dct(u, norm=norm)
+             elif dim == 2: dct = torch_dct.dct_2d(u, norm=norm)
+             elif dim == 3: dct = torch_dct.dct_3d(u, norm=norm)
+             else: raise ValueError(f"Unsupported number of dimensions {dim}")
+
+             projected.append(dct)
+
+         return projected
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         settings = self.settings[vars.params[0]]
+         dims = settings['dims']
+         norm = settings['norm']
+
+         unprojected = []
+         for u in tensors:
+             dim = min(u.ndim, dims)
+
+             if dim == 1: idct = torch_dct.idct(u, norm=norm)
+             elif dim == 2: idct = torch_dct.idct_2d(u, norm=norm)
+             elif dim == 3: idct = torch_dct.idct_3d(u, norm=norm)
+             else: raise ValueError(f"Unsupported number of dimensions {dim}")
+
+             unprojected.append(reverse_dims(idct))
+
+         return unprojected
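
Since DCTProjection relies on the torch_dct package for both directions, here is a quick standalone check (not part of torchzero) that the transform pair used above inverts cleanly, which is what lets unproject recover the update computed in DCT space:

import torch
import torch_dct

u = torch.randn(6, 8)
coeffs = torch_dct.dct_2d(u, norm='ortho')          # forward transform, as in DCTProjection.project for a 2d tensor
restored = torch_dct.idct_2d(coeffs, norm='ortho')  # inverse, as in DCTProjection.unproject
print(torch.allclose(u, restored, atol=1e-5))       # True (the reverse_dims permutation is omitted here)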
torchzero/modules/projections/fft.py
@@ -0,0 +1,73 @@
+ import torch
+
+ from ...core import Chainable
+ from ...utils import vec_to_tensors
+ from .projection import Projection
+
+
+ class FFTProjection(Projection):
+     # norm description copied from pytorch docstring
+     """Project update into Fourrier space of real-valued inputs.
+
+     Args:
+         modules (Chainable): modules that will optimize the projected update.
+         one_d (bool, optional):
+             * If True, uses 1d fft on parameters concatenated into a vector.
+             * If False, uses n-dimensional fft on each parameter (default).
+         norm (str, optional):
+             Normalization mode.
+
+             * "forward" - normalize by 1/n
+             * "backward" - no normalization
+             * "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)
+
+             Calling the backward transform (:func:`~torch.fft.irfft`) with the same
+             normalization mode will apply an overall normalization of ``1/n`` between
+             the two transforms. This is required to make :func:`~torch.fft.irfft`
+             the exact inverse.
+
+             Default is "backward" (no normalization).
+
+             The actual torch.fft.rfft default is None, so I set it to None too. I guess None and "backward"
+             are the same.
+     """
+
+     def __init__(
+         self,
+         modules: Chainable,
+         one_d: bool = False,
+         norm=None,
+         project_update=True,
+         project_params=False,
+         project_grad=False,
+     ):
+         defaults = dict(one_d=one_d, norm=norm)
+         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         settings = self.settings[vars.params[0]]
+         one_d = settings['one_d']
+         norm = settings['norm']
+
+         # 1d fft, concatenate all parameters into a vector and calculate fft
+         if one_d:
+             vec = torch.cat([t.view(-1) for t in tensors])
+             self.global_state['length'] = len(vec)
+             return [torch.view_as_real(torch.fft.rfft(vec, norm=norm))] # pylint:disable=not-callable
+
+         # multidimensional fft for each parameter
+         return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         settings = self.settings[vars.params[0]]
+         one_d = settings['one_d']
+         norm = settings['norm']
+
+         if one_d:
+             vec = torch.view_as_complex(tensors[0])
+             unprojected_vec = torch.fft.irfft(vec, n=self.global_state['length'], norm=norm) # pylint:disable=not-callable
+             return vec_to_tensors(unprojected_vec, reference=vars.params)
+
+         return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(tensors, vars.params)] # pylint:disable=not-callable
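
FFTProjection.unproject has to pass the original length (1d case) or shape (n-d case) back into the inverse transform; without it, torch.fft.irfft/irfftn assume an even-sized output and odd-sized parameters come back with the wrong shape. A minimal standalone sketch of the n-dimensional round trip used above:

import torch

t = torch.randn(4, 5)
packed = torch.view_as_real(torch.fft.rfftn(t))                         # real view of the half-spectrum, as in project
restored = torch.fft.irfftn(torch.view_as_complex(packed), s=t.shape)   # s=... restores the odd last dimension
print(torch.allclose(t, restored, atol=1e-5))                           # True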
torchzero/modules/projections/galore.py
@@ -0,0 +1,10 @@
+ import importlib.util
+ import warnings
+ from collections.abc import Callable, Mapping
+ from operator import itemgetter
+ from typing import Any, Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Vars
+ from .projection import Projection