torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/experimental/absoap.py (new file)
@@ -0,0 +1,350 @@
+from operator import itemgetter
+
+import torch
+from typing import Literal
+from ...core import Chainable, Transform, apply
+from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+
+@torch.no_grad
+def update_soap_covariances_(
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    GGs_: list[torch.Tensor | None],
+    beta: float | None,
+):
+    for i, GG in enumerate(GGs_):
+        if GG is None: continue
+
+        axes = list(range(i)) + list(range(i + 1, g1.ndim)) # this works fine with 1d params
+        if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
+        else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
+
+@torch.no_grad
+def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+    """
+    Projects the gradient to the eigenbases of the preconditioner.
+    """
+    for mat in Q:
+        if mat is None: continue
+        if len(mat) > 0:
+            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
+        else:
+            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+            permute_order = list(range(1, len(tensors.shape))) + [0]
+            tensors = tensors.permute(permute_order)
+
+    return tensors
+
+@torch.no_grad
+def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
+    """
+    Projects the gradient back to the original space.
+    """
+    for mat in Q:
+        if mat is None: continue
+        if len(mat) > 0:
+            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
+        else:
+            permute_order = list(range(1, len(tensors.shape))) + [0]
+            tensors = tensors.permute(permute_order)
+
+    return tensors
+
+# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+@torch.no_grad
+def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
+    """
+    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
+    """
+    matrix = []
+    float_data = False
+    original_type = original_device = None
+    for m in mat:
+        if m is None: continue
+        if len(m) == 0:
+            matrix.append([])
+            continue
+        if m.dtype != torch.float:
+            original_type = m.dtype
+            original_device = m.device
+            matrix.append(m.float())
+        else:
+            float_data = True
+            matrix.append(m)
+
+    final = []
+    for m in matrix:
+        if len(m) == 0:
+            final.append([])
+            continue
+        try:
+            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+        except Exception:
+            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+            Q = Q.to(m.dtype)
+        Q = torch.flip(Q, [1])
+
+        if not float_data:
+            Q = Q.to(original_device).type(original_type)
+        final.append(Q)
+    return final
+
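Editorial note (not part of the diff): the helpers above follow the reference SOAP implementation. update_soap_covariances_ accumulates one covariance matrix per tensor dimension, get_orthogonal_matrix eigendecomposes them, and project / project_back rotate a tensor into and out of that eigenbasis. A minimal stand-alone sketch of what this amounts to for a 2-D parameter; the names G, L, R, QL, QR are illustrative only:

import torch

G = torch.randn(32, 16)      # gradient of a 32x16 parameter

# one Kronecker-factored covariance per dimension (what state['GG'] holds)
L = G @ G.T                  # (32, 32): tensordot over every axis except dim 0
R = G.T @ G                  # (16, 16): tensordot over every axis except dim 1

# eigenbases of the preconditioner (what get_orthogonal_matrix computes)
_, QL = torch.linalg.eigh(L)
_, QR = torch.linalg.eigh(R)

# project: rotate the gradient into the eigenbasis, one tensordot per dimension
G_rot = QL.T @ G @ QR        # equivalent to project(G, [QL, QR])

# ... a diagonal Adam update runs on G_rot here ...

# project_back: rotate the preconditioned update back to the original space
update = QL @ G_rot @ QR.T   # equivalent to project_back(G_rot, [QL, QR])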
+# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
+@torch.no_grad
+def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
+    """
+    Computes the eigenbases of the preconditioner using one round of power iteration
+    followed by torch.linalg.qr decomposition.
+    """
+    matrix = []
+    orth_matrix = []
+    float_data = False
+    original_type = original_device = None
+    for m,o in zip(GG, Q_list):
+        if m is None: continue
+        assert o is not None
+
+        if len(m) == 0:
+            matrix.append([])
+            orth_matrix.append([])
+            continue
+        if m.data.dtype != torch.float:
+            original_type = m.data.dtype
+            original_device = m.data.device
+            matrix.append(m.data.float())
+            orth_matrix.append(o.data.float())
+        else:
+            float_data = True
+            matrix.append(m.data.float())
+            orth_matrix.append(o.data.float())
+
+    final = []
+    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
+        if len(m)==0:
+            final.append([])
+            continue
+        est_eig = torch.diag(o.T @ m @ o)
+        sort_idx = torch.argsort(est_eig, descending=True)
+        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
+        o = o[:,sort_idx]
+        power_iter = m @ o
+        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
+
+        if not float_data:
+            Q = Q.to(original_device).type(original_type)
+        final.append(Q)
+
+    return final, exp_avg_sq
+
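Editorial note (not part of the diff): get_orthogonal_matrix_QR avoids a full eigendecomposition on every refresh. It sorts the current basis by estimated eigenvalue, takes one power-iteration step against the covariance, and re-orthogonalizes with QR; the second moment is re-indexed with the same ordering so it stays aligned with the sorted basis. Roughly, with illustrative names:

import torch

GG = torch.randn(32, 32)
GG = GG @ GG.T                                   # current covariance estimate (symmetric PSD)
Q_old, _ = torch.linalg.qr(torch.randn(32, 32))  # previous eigenbasis (orthonormal columns)

est_eig = torch.diag(Q_old.T @ GG @ Q_old)       # Rayleigh-quotient estimates of the eigenvalues
order = torch.argsort(est_eig, descending=True)
Q_old = Q_old[:, order]                          # keep columns sorted by estimated eigenvalue
Q_new, _ = torch.linalg.qr(GG @ Q_old)           # one power-iteration step + re-orthogonalization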
+Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
+class ABSOAP(Transform):
+    """SOAP but with two extra letters included in its name in order to improve converence
+
+    new args
+
+    scale by s whether to scale gradient differences by parameter differences
+
+    y_to_ema2 whether to use gradient differences for exponential moving average too
+    """
+    def __init__(
+        self,
+        beta1: float = 0.95,
+        beta2: float = 0.95,
+        shampoo_beta: float | None = 0.95,
+        precond_freq: int = 10,
+        merge_small: bool = True,
+        max_dim: int = 2_000,
+        precondition_1d: bool = True,
+        eps: float = 1e-8,
+        decay: float | None = None,
+        alpha: float = 1,
+        bias_correction: bool = True,
+        scale_by_s: bool = True,
+        first: Source='g',
+        second: Source='g',
+        ema1: Source='g',
+        ema2: tuple[Source, Source] = ('g','g'),
+        rel1: bool=False,
+        rel2: bool=False,
+        norm: bool = False,
+    ):
+        defaults = dict(
+            beta1=beta1,
+            beta2=beta2,
+            shampoo_beta=shampoo_beta,
+            precond_freq=precond_freq,
+            merge_small=merge_small,
+            max_dim=max_dim,
+            precondition_1d=precondition_1d,
+            eps=eps,
+            decay=decay,
+            bias_correction=bias_correction,
+            alpha=alpha,
+            scale_by_s=scale_by_s,
+            ema1=ema1,
+            ema2=ema2,
+            first=first,
+            second=second,
+            rel1=rel1, rel2=rel2,
+            norm=norm,
+        )
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        updates = []
+        # update preconditioners
+        for i,(p,t) in enumerate(zip(params, tensors)):
+            state = self.state[p]
+            settings = self.settings[p]
+            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
+            scale_by_s = settings['scale_by_s']
+            ema1 = settings['ema1']
+            ema2 = settings['ema2']
+            first=settings['first']
+            second=settings['second']
+            rel1 = settings['rel1']; rel2 = settings['rel2']
+            norm=settings['norm']
+
+            if merge_small:
+                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
+
+            if 'g_prev' not in state:
+                state['p_prev'] = p.clone()
+                state['g_prev'] = t.clone()
+                updates.append(tensors[i].sign())
+                continue
+
+            p_prev = state['p_prev']
+            g_prev = state['g_prev']
+            s = p - p_prev
+            y = t - g_prev
+
+            # keep malding
+            p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
+            g_norm = torch.linalg.vector_norm(t) # pylint:disable=not-callable
+            s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
+            y_norm = torch.linalg.vector_norm(y) # pylint:disable=not-callable
+
+            sn = p - p_prev * (p_norm / torch.linalg.vector_norm(p_prev))# pylint:disable=not-callable
+            yn = t - g_prev * (g_norm / torch.linalg.vector_norm(g_prev))# pylint:disable=not-callable
+
+            if scale_by_s: y /= s_norm.clip(min=1e-8) # pylint:disable=not-callable
+
+            state['p_prev'].copy_(p)
+            state['g_prev'].copy_(t)
+
+            def _get(c: Source):
+                if c == 'p': return p
+                if c == 'g': return t
+                if c == 's': return s
+                if c == 'y': return y
+                if c == 'sn': return sn
+                if c == 'yn': return yn
+                if c == 'gy': return t+y
+                if c == 'sy': return s+y
+                if c == 'gys':
+                    y_scaled = y * (g_norm/y_norm.clip(min=1e-8))
+                    return t+y_scaled
+                if c == 'sys':
+                    y_scaled = y * (s_norm/y_norm.clip(min=1e-8))
+                    return s+y_scaled
+                raise RuntimeError("Big Chungus")
+
+            t1 = _get(first)
+            if rel1: t1 = t1 * p.abs().clip(min=1e-6)
+            t2 = _get(second)
+            if rel2: t2 = t2 * p.abs().clip(min=1e-6)
+
+            t_ema1 = _get(ema1)
+            t_ema2s = _get(ema2[0]), _get(ema2[1])
+
+            if norm:
+                t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
+                t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable
+
+
+            # initialize state on 1st step
+            if 'GG' not in state:
+                state["exp_avg"] = torch.zeros_like(t)
+                state["exp_avg_sq"] = torch.ones_like(t)
+
+                if not precondition_1d and t.ndim <= 1:
+                    state['GG'] = []
+
+                else:
+                    state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
+
+                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+                if len([i is not None for i in state['GG']]) == 0:
+                    state['GG'] = None
+
+                if state['GG'] is not None:
+                    update_soap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
+                    state['Q'] = get_orthogonal_matrix(state['GG'])
+
+                state['step'] = 0
+                updates.append(tensors[i].sign())
+                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
+                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            z1_projected = None
+            z2_projected = None
+
+            if state['GG'] is not None:
+                z1_projected = project(t_ema2s[0], state['Q'])
+                if ema2[0] == ema2[1]: z2_projected = z1_projected
+                else: z2_projected = project(t_ema2s[1], state['Q'])
+
+            # exponential moving averages
+            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
+            exp_avg: torch.Tensor = state["exp_avg"]
+            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+
+            exp_avg.lerp_(t_ema1, 1-beta1)
+
+            if z1_projected is None:
+                exp_avg_sq.mul_(beta2).addcmul_(*t_ema2s, value=1-beta2)
+            else:
+                assert z2_projected is not None
+                exp_avg_sq.mul_(beta2).addcmul_(z1_projected, z2_projected, value=1-beta2)
+
+            # project exponential moving averages if they are accumulated unprojected
+            exp_avg_projected = exp_avg
+            if z1_projected is not None:
+                exp_avg_projected = project(exp_avg, state['Q'])
+
+            exp_avg_sq_projected = exp_avg_sq
+
+            denom = exp_avg_sq_projected.sqrt().add_(eps)
+            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            update = exp_avg_projected / denom
+            if z1_projected is not None:
+                update = project_back(update, state["Q"])
+
+            if settings['bias_correction']:
+                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
+                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
+                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
+            elif alpha is not None:
+                update *= alpha
+
+            if merge_small:
+                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
+
+            updates.append(update)
+            state["step"] += 1
+
+            # Update is done after the gradient step to avoid using current gradients in the projection.
+            if state['GG'] is not None:
+                update_soap_covariances_(t1, t2, state['GG'], shampoo_beta)
+                if state['step'] % settings['precond_freq'] == 0:
+                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
+
+        return updates
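Editorial note (not part of the diff): the ABSOAP twist over plain SOAP is that the tensors fed into the covariance update (first / second) and into the Adam moments (ema1 / ema2) are configurable "sources" built from the current and previous parameters and gradients. A sketch of how those sources are formed, with illustrative names only:

import torch

p, p_prev = torch.randn(10), torch.randn(10)   # current / previous parameters
g, g_prev = torch.randn(10), torch.randn(10)   # current / previous gradients (or incoming updates)

s = p - p_prev                                                         # 's' : parameter difference
y = (g - g_prev) / torch.linalg.vector_norm(s).clamp(min=1e-8)         # 'y' : gradient difference, scaled by ||s|| when scale_by_s=True
sn = p - p_prev * (torch.linalg.vector_norm(p) / torch.linalg.vector_norm(p_prev))  # 'sn': p minus p_prev rescaled to the current norm
yn = g - g_prev * (torch.linalg.vector_norm(g) / torch.linalg.vector_norm(g_prev))  # 'yn': same idea for gradients
gy = g + y                                                             # 'gy': gradient plus gradient difference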
torchzero/modules/experimental/adadam.py (new file)
@@ -0,0 +1,111 @@
+from operator import itemgetter
+from functools import partial
+
+import torch
+
+from ...core import Module, Target, Transform
+from ...utils import NumberList, TensorList
+from ..functional import (
+    debias, debiased_step_size,
+    ema_,
+    sqrt_ema_sq_,
+)
+from ..lr.lr import lazy_lr
+from ..momentum.experimental import sqrt_nag_ema_sq_
+from ..momentum.momentum import nag_
+
+
+def adadam_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    exp_avg_qu_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    precond_beta: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    debiased: bool = True,
+    max_exp_avg_sq_: TensorList | None = None,
+    max_exp_avg_qu_: TensorList | None = None,
+    params_: TensorList | None = None,
+):
+    """Returns new tensors or updates params in-place."""
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+
+    sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
+                                   debiased=False,step=step,pow=pow)
+    sqrt_exp_avg_qu = sqrt_ema_sq_(tensors/(sqrt_exp_avg_sq+1e-8), exp_avg_sq_=exp_avg_qu_,
+                                   beta=precond_beta,max_exp_avg_sq_=max_exp_avg_qu_, debiased=False,step=step,pow=pow)
+
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+
+    # params is None, return update
+    if params_ is None: return (exp_avg_ / sqrt_exp_avg_qu.add_(eps)).lazy_mul(alpha)
+
+    # update params in-place
+    params_.addcdiv_(exp_avg_, sqrt_exp_avg_qu.add_(eps), -alpha)
+    return None
+
+class Adadam(Module):
+    """Adam with a diagonally preconditioned preconditioner and a graceful name."""
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        precond_beta: float = 0.999,
+        eps: float = 1e-8,
+        amsgrad: bool = False,
+        alpha: float = 1.,
+        pow: float = 2,
+        debiased: bool = True,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,precond_beta=precond_beta,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
+        super().__init__(defaults)
+        self.getter = itemgetter('amsgrad','pow','debiased')
+
+    @torch.no_grad
+    def step(self, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,precond_beta,eps,alpha=self.get_settings('beta1','beta2','precond_beta','eps','alpha', params=vars.params, cls=NumberList)
+        amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', params=vars.params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq, exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu', params=vars.params, cls=TensorList)
+            max_exp_avg_sq = None
+            max_exp_avg_qu = None
+
+        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
+        if vars.is_last:
+            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
+            passed_params = TensorList(vars.params)
+            vars.stop = True
+            vars.skip_update = True
+
+        else:
+            passed_params = None
+
+        vars.update = adadam_(
+            tensors=TensorList(vars.get_update()),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            exp_avg_qu_=exp_avg_qu,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            precond_beta=precond_beta,
+            eps=eps,
+            step=step,
+            pow=pow,
+            debiased=debiased,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            max_exp_avg_qu_=max_exp_avg_qu,
+            params_=passed_params,
+        )
+
+        return vars
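Editorial note (not part of the diff): Adadam keeps the usual Adam moments plus a third buffer, exp_avg_qu, which tracks the squares of the already Adam-preconditioned gradient; the final step divides by that buffer rather than by exp_avg_sq. A rough per-tensor sketch that ignores debiasing and the AMSGrad variants; names are illustrative, not the package API:

import torch

def adadam_step_sketch(g, exp_avg, exp_avg_sq, exp_avg_qu,
                       beta1=0.9, beta2=0.999, precond_beta=0.999, eps=1e-8, alpha=1.0):
    exp_avg.lerp_(g, 1 - beta1)                               # first moment, as in Adam
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)    # ordinary second moment
    g_precond = g / (exp_avg_sq.sqrt() + 1e-8)                # Adam-preconditioned gradient
    exp_avg_qu.mul_(precond_beta).addcmul_(g_precond, g_precond, value=1 - precond_beta)
    return alpha * exp_avg / (exp_avg_qu.sqrt() + eps)        # step uses the second-level moment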
torchzero/modules/experimental/adamY.py (new file)
@@ -0,0 +1,135 @@
+from operator import itemgetter
+from functools import partial
+
+import torch
+
+from ...core import Module, Target, Transform
+from ...utils import NumberList, TensorList
+from ..functional import (
+    debias, debiased_step_size,
+    ema_,
+    sqrt_ema_sq_,
+)
+from ..lr.lr import lazy_lr
+from ..momentum.experimental import sqrt_nag_ema_sq_
+from ..momentum.momentum import nag_
+
+
+def adamy_(
+    p: TensorList,
+    p_prev: TensorList,
+    g: TensorList,
+    g_prev: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    debiased: bool = True,
+    max_exp_avg_sq_: TensorList | None = None,
+    params_: TensorList | None = None,
+):
+    """Returns new tensors or updates params in-place."""
+    if step == 1:
+        p_prev.copy_(p)
+        g_prev.copy_(g)
+
+        update = g.sign().lazy_mul_(alpha*0.1)
+        if params_ is None: return update
+        params_.sub_(update)
+        return None
+
+    s = p-p_prev
+    y = (g-g_prev).div_(s.global_vector_norm().clip(min=1e-8))
+    p_prev.copy_(p)
+    g_prev.copy_(g)
+
+    exp_avg_ = ema_(g, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+
+    sqrt_exp_avg_sq = sqrt_ema_sq_(y, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
+                                   debiased=False,step=step,pow=pow)
+
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+
+    # params is None, return update
+    if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
+
+    # update params in-place
+    params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
+    return None
+
+class AdamY(Module):
+    """Adam but uses scaled gradient differences for second momentum.
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum. Defaults to 0.999.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        alpha (float, optional): learning rate. Defaults to 1.
+        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+        pow (float, optional): power used in second momentum power and root. Defaults to 2.
+        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        eps: float = 1e-8,
+        amsgrad: bool = False,
+        alpha: float = 1.,
+        pow: float = 2,
+        debiased: bool = True,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
+        super().__init__(defaults)
+        self.getter = itemgetter('amsgrad','pow','debiased')
+
+    @torch.no_grad
+    def step(self, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,eps,alpha=self.get_settings('beta1','beta2','eps','alpha', params=vars.params, cls=NumberList)
+        amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg','exp_avg_sq','max_exp_avg_sq', params=vars.params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg','exp_avg_sq', params=vars.params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
+        if vars.is_last:
+            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
+            passed_params = TensorList(vars.params)
+            vars.stop = True
+            vars.skip_update = True
+
+        else:
+            passed_params = None
+
+        p_prev = self.get_state('p_prev', params=vars.params, cls=TensorList)
+        g_prev = self.get_state('g_prev', params=vars.params, cls=TensorList)
+
+
+        vars.update = adamy_(
+            p=TensorList(vars.params),
+            p_prev=p_prev,
+            g=TensorList(vars.get_update()),
+            g_prev=g_prev,
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            eps=eps,
+            step=step,
+            pow=pow,
+            debiased=debiased,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            params_=passed_params,
+        )
+
+        return vars
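Editorial note (not part of the diff): AdamY differs from Adam only in what feeds the second moment. After the first step it uses the gradient difference y = g - g_prev, divided by the norm of the parameter step s = p - p_prev, instead of the gradient itself; the first moment still averages the gradient. A rough per-tensor sketch ignoring debiasing and AMSGrad, with illustrative names:

import torch

def adamy_step_sketch(p, p_prev, g, g_prev, exp_avg, exp_avg_sq,
                      beta1=0.9, beta2=0.999, eps=1e-8, alpha=1.0):
    s = p - p_prev                                               # parameter difference
    y = (g - g_prev) / torch.linalg.vector_norm(s).clamp(min=1e-8)  # scaled gradient difference
    exp_avg.lerp_(g, 1 - beta1)                                  # first moment still uses the gradient
    exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1 - beta2)       # second moment uses y instead of g
    return alpha * exp_avg / (exp_avg_sq.sqrt() + eps)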