torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
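
Note on the layout change visible in the list above: files under torchzero/modules/optimizers/ now live under torchzero/modules/adaptive/, the old lr/ package is replaced by step_size/, and several experimental modules were deleted outright (their full contents appear as removals below). The snippet here is a hypothetical compatibility shim inferred only from these file moves; torchzero may also re-export the same names at a higher level, so treat the exact import paths as assumptions rather than documented API.

    # Hypothetical shim based solely on the renames listed above.
    try:
        from torchzero.modules.adaptive import adam, soap, shampoo   # 0.3.13 layout
    except ImportError:
        from torchzero.modules.optimizers import adam, soap, shampoo  # 0.3.10 layout

The removed experimental modules are shown in full below.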
torchzero/modules/experimental/absoap.py (deleted)
@@ -1,250 +0,0 @@
- from operator import itemgetter
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Transform
- from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
- from ..optimizers.soap import project, project_back, get_orthogonal_matrix, get_orthogonal_matrix_QR
-
- @torch.no_grad
- def update_absoap_covariances_(
-     g1: torch.Tensor,
-     g2: torch.Tensor,
-     GGs_: list[torch.Tensor | None],
-     beta: float | None,
- ):
-     for i, GG in enumerate(GGs_):
-         if GG is None: continue
-
-         axes = list(range(i)) + list(range(i + 1, g1.ndim)) # this works fine with 1d params
-         if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
-         else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-
- Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
- class ABSOAP(Transform):
-     """SOAP but with some extra options for testing. Please note that this is experimental and isn't guaranteed to work.
-
-     Args:
-         scale_by_s - whether to scale y by s
-         gg1 - 1st vector into GGᵀ
-         gg2 - 2nd vector into GGᵀ
-         ema1 - vector into 1st momentum
-         ema2 - 2 vectors into 2nd momentum
-         rel1 - if True, multiplies gg1 by params
-         rel2 - same but for gg2
-         norm - if True, gg1 a and gg2 are normalized, and I need to make that into a letter
-
-     letters:
-         p - params
-         g - grad
-         s - param difference
-         y - grad difference
-         gy - g+y
-         sy - s+y
-         sn - s normalized
-         yn - y normalized
-         gys - g + y#g
-         sys - s + y#s
-
-     """
-     def __init__(
-         self,
-         beta1: float = 0.95,
-         beta2: float = 0.95,
-         shampoo_beta: float | None = 0.95,
-         precond_freq: int = 10,
-         merge_small: bool = True,
-         max_dim: int = 2_000,
-         precondition_1d: bool = True,
-         eps: float = 1e-8,
-         decay: float | None = None,
-         alpha: float = 1,
-         bias_correction: bool = True,
-         scale_by_s: bool = True,
-         gg1: Source='g',
-         gg2: Source='g',
-         ema1: Source='g',
-         ema2: tuple[Source, Source] = ('g','g'),
-         rel1: bool=False,
-         rel2: bool=False,
-         norm: bool = False,
-     ):
-         defaults = dict(
-             beta1=beta1,
-             beta2=beta2,
-             shampoo_beta=shampoo_beta,
-             precond_freq=precond_freq,
-             merge_small=merge_small,
-             max_dim=max_dim,
-             precondition_1d=precondition_1d,
-             eps=eps,
-             decay=decay,
-             bias_correction=bias_correction,
-             alpha=alpha,
-             scale_by_s=scale_by_s,
-             ema1=ema1,
-             ema2=ema2,
-             first=gg1,
-             second=gg2,
-             rel1=rel1, rel2=rel2,
-             norm=norm,
-         )
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         updates = []
-         # update preconditioners
-         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
-             scale_by_s = setting['scale_by_s']
-             ema1 = setting['ema1']
-             ema2 = setting['ema2']
-             first=setting['first']
-             second=setting['second']
-             rel1 = setting['rel1']; rel2 = setting['rel2']
-             norm=setting['norm']
-
-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-             if 'g_prev' not in state:
-                 state['p_prev'] = p.clone()
-                 state['g_prev'] = t.clone()
-                 # updates.append(tensors[i].clip(-0.1,0.1))
-                 # continue
-
-             p_prev = state['p_prev']
-             g_prev = state['g_prev']
-             s = p - p_prev
-             y = t - g_prev
-
-             # keep malding
-             p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
-             g_norm = torch.linalg.vector_norm(t) # pylint:disable=not-callable
-             s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
-             y_norm = torch.linalg.vector_norm(y) # pylint:disable=not-callable
-
-             sn = p - p_prev * (p_norm / torch.linalg.vector_norm(p_prev))# pylint:disable=not-callable
-             yn = t - g_prev * (g_norm / torch.linalg.vector_norm(g_prev))# pylint:disable=not-callable
-
-             if scale_by_s: y /= s_norm.clip(min=1e-8) # pylint:disable=not-callable
-
-             state['p_prev'].copy_(p)
-             state['g_prev'].copy_(t)
-
-             def _get(c: Source):
-                 if c == 'p': return p
-                 if c == 'g': return t
-                 if c == 's': return s
-                 if c == 'y': return y
-                 if c == 'sn': return sn
-                 if c == 'yn': return yn
-                 if c == 'gy': return t+y
-                 if c == 'sy': return s+y
-                 if c == 'gys':
-                     y_scaled = y * (g_norm/y_norm.clip(min=1e-8))
-                     return t+y_scaled
-                 if c == 'sys':
-                     y_scaled = y * (s_norm/y_norm.clip(min=1e-8))
-                     return s+y_scaled
-                 raise RuntimeError("Big Chungus")
-
-             t1 = _get(first)
-             if rel1: t1 = t1 * p.abs().clip(min=1e-6)
-             t2 = _get(second)
-             if rel2: t2 = t2 * p.abs().clip(min=1e-6)
-
-             t_ema1 = _get(ema1)
-             t_ema2s = _get(ema2[0]), _get(ema2[1])
-
-             if norm:
-                 t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
-                 t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable
-
-             # initialize state on 1st step
-             if 'GG' not in state:
-                 state["exp_avg"] = torch.zeros_like(t)
-                 state["exp_avg_sq"] = torch.zeros_like(t)
-
-                 if not precondition_1d and t.ndim <= 1:
-                     state['GG'] = []
-
-                 else:
-                     state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                 if len([i is not None for i in state['GG']]) == 0:
-                     state['GG'] = None
-
-                 if state['GG'] is not None:
-                     update_absoap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
-                     state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                 state['step'] = 0
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             z1_projected = None
-             z2_projected = None
-
-             if state['GG'] is not None:
-                 z1_projected = project(t_ema2s[0], state['Q'])
-                 if ema2[0] == ema2[1]: z2_projected = z1_projected
-                 else: z2_projected = project(t_ema2s[1], state['Q'])
-
-             # exponential moving averages
-             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-             exp_avg: torch.Tensor = state["exp_avg"]
-             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-             exp_avg.lerp_(t_ema1, 1-beta1)
-
-             if z1_projected is None:
-                 exp_avg_sq.mul_(beta2).addcmul_(*t_ema2s, value=1-beta2)
-             else:
-                 assert z2_projected is not None
-                 exp_avg_sq.mul_(beta2).addcmul_(z1_projected, z2_projected, value=1-beta2)
-
-             # project exponential moving averages if they are accumulated unprojected
-             exp_avg_projected = exp_avg
-             if z1_projected is not None:
-                 exp_avg_projected = project(exp_avg, state['Q'])
-
-             exp_avg_sq_projected = exp_avg_sq
-
-             denom = exp_avg_sq_projected.sqrt().add_(eps)
-             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             update = exp_avg_projected / denom
-             if z1_projected is not None:
-                 update = project_back(update, state["Q"])
-
-             if setting['bias_correction']:
-                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-             elif alpha is not None:
-                 update *= alpha
-
-             if merge_small:
-                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-             updates.append(update)
-             state["step"] += 1
-
-             # Update is done after the gradient step to avoid using current gradients in the projection.
-             if state['GG'] is not None:
-                 update_absoap_covariances_(t1, t2, state['GG'], shampoo_beta)
-                 if state['step'] % setting['precond_freq'] == 0:
-                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-         return updates
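
The module above (like the other SOAP-style modules in this release) maintains one covariance matrix GGᵀ per tensor dimension, accumulated by contracting the gradient with itself over every other axis via torch.tensordot. A minimal standalone sketch of that update, with illustrative names rather than torchzero's API:

    import torch

    def update_covariances_(g1, g2, GGs, beta=0.95):
        # One GG^T accumulator per dimension; contracting g1 with g2 over all
        # other axes yields a (shape[i], shape[i]) matrix for dimension i.
        for i, GG in enumerate(GGs):
            if GG is None:
                continue
            axes = list(range(i)) + list(range(i + 1, g1.ndim))
            outer = torch.tensordot(g1, g2, dims=(axes, axes))
            if beta is None:
                GG.add_(outer)
            else:
                GG.lerp_(outer, 1 - beta)

    g = torch.randn(8, 16)
    GGs = [torch.zeros(8, 8), torch.zeros(16, 16)]
    update_covariances_(g, g, GGs)        # plain SOAP: gradient with itself
    update_covariances_(g, g * 2, GGs)    # ABSOAP above allows two different source vectors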
torchzero/modules/experimental/adadam.py (deleted)
@@ -1,112 +0,0 @@
- from operator import itemgetter
- from functools import partial
-
- import torch
-
- from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList
- from ..functional import (
-     debias, debiased_step_size,
-     ema_,
-     sqrt_ema_sq_,
- )
- from ..lr.lr import lazy_lr
- from ..momentum.experimental import sqrt_nag_ema_sq_
- from ..momentum.momentum import nag_
-
-
- def adadam_(
-     tensors: TensorList,
-     exp_avg_: TensorList,
-     exp_avg_sq_: TensorList,
-     exp_avg_qu_: TensorList,
-     alpha: float | NumberList,
-     beta1: float | NumberList,
-     beta2: float | NumberList,
-     precond_beta: float | NumberList,
-     eps: float | NumberList,
-     step: int,
-     pow: float = 2,
-     debiased: bool = True,
-     max_exp_avg_sq_: TensorList | None = None,
-     max_exp_avg_qu_: TensorList | None = None,
-     params_: TensorList | None = None,
- ):
-     """Returns new tensors or updates params in-place."""
-     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
-     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-         debiased=False,step=step,pow=pow)
-     sqrt_exp_avg_qu = sqrt_ema_sq_(tensors/(sqrt_exp_avg_sq+1e-8), exp_avg_sq_=exp_avg_qu_,
-         beta=precond_beta,max_exp_avg_sq_=max_exp_avg_qu_, debiased=False,step=step,pow=pow)
-
-     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-     # params is None, return update
-     if params_ is None: return (exp_avg_ / sqrt_exp_avg_qu.add_(eps)).lazy_mul(alpha)
-
-     # update params in-place
-     params_.addcdiv_(exp_avg_, sqrt_exp_avg_qu.add_(eps), -alpha)
-     return None
-
- class Adadam(Module):
-     """Adam with a diagonally preconditioned preconditioner. Please note that this is experimental and isn't guaranteed to work."""
-     def __init__(
-         self,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         precond_beta: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad: bool = False,
-         alpha: float = 1.,
-         pow: float = 2,
-         debiased: bool = True,
-     ):
-         defaults=dict(beta1=beta1,beta2=beta2,precond_beta=precond_beta,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-         super().__init__(defaults)
-         self.getter = itemgetter('amsgrad','pow','debiased')
-
-     @torch.no_grad
-     def step(self, var):
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-         params = var.params
-
-         beta1,beta2,precond_beta,eps,alpha=self.get_settings(params, 'beta1','beta2','precond_beta','eps','alpha', cls=NumberList)
-         amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
-
-         if amsgrad:
-             exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', cls=TensorList)
-         else:
-             exp_avg, exp_avg_sq, exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', cls=TensorList)
-             max_exp_avg_sq = None
-             max_exp_avg_qu = None
-
-         # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-         if var.is_last:
-             if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
-             passed_params = TensorList(var.params)
-             var.stop = True
-             var.skip_update = True
-
-         else:
-             passed_params = None
-
-         var.update = adadam_(
-             tensors=TensorList(var.get_update()),
-             exp_avg_=exp_avg,
-             exp_avg_sq_=exp_avg_sq,
-             exp_avg_qu_=exp_avg_qu,
-             alpha=alpha,
-             beta1=beta1,
-             beta2=beta2,
-             precond_beta=precond_beta,
-             eps=eps,
-             step=step,
-             pow=pow,
-             debiased=debiased,
-             max_exp_avg_sq_=max_exp_avg_sq,
-             max_exp_avg_qu_=max_exp_avg_qu,
-             params_=passed_params,
-         )
-
-         return var
1
- from operator import itemgetter
2
- from functools import partial
3
-
4
- import torch
5
-
6
- from ...core import Module, Target, Transform
7
- from ...utils import NumberList, TensorList
8
- from ..functional import (
9
- debias, debiased_step_size,
10
- ema_,
11
- sqrt_ema_sq_,
12
- )
13
- from ..lr.lr import lazy_lr
14
- from ..momentum.experimental import sqrt_nag_ema_sq_
15
- from ..momentum.momentum import nag_
16
-
17
-
18
- def adamy_(
19
- p: TensorList,
20
- p_prev: TensorList,
21
- g: TensorList,
22
- g_prev: TensorList,
23
- exp_avg_: TensorList,
24
- exp_avg_sq_: TensorList,
25
- alpha: float | NumberList,
26
- beta1: float | NumberList,
27
- beta2: float | NumberList,
28
- eps: float | NumberList,
29
- step: int,
30
- pow: float = 2,
31
- debiased: bool = True,
32
- max_exp_avg_sq_: TensorList | None = None,
33
- params_: TensorList | None = None,
34
- ):
35
- """Returns new tensors or updates params in-place."""
36
- if step == 1:
37
- p_prev.copy_(p)
38
- g_prev.copy_(g)
39
-
40
- update = g.clip(-0.1,0.1).lazy_mul_(alpha)
41
- if params_ is None: return update
42
- params_.sub_(update)
43
- return None
44
-
45
- s = p-p_prev
46
- y = (g-g_prev).div_(s.global_vector_norm().clip(min=1e-8))
47
- p_prev.copy_(p)
48
- g_prev.copy_(g)
49
-
50
- exp_avg_ = ema_(g, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
51
-
52
- sqrt_exp_avg_sq = sqrt_ema_sq_(y, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
53
- debiased=False,step=step,pow=pow)
54
-
55
- if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
56
-
57
- # params is None, return update
58
- if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
59
-
60
- # update params in-place
61
- params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
62
- return None
63
-
64
- class AdamY(Module):
65
- """Adam but uses scaled gradient differences for second momentum. Please note that this is experimental and isn't guaranteed to work."""
66
- def __init__(
67
- self,
68
- beta1: float = 0.9,
69
- beta2: float = 0.999,
70
- eps: float = 1e-8,
71
- amsgrad: bool = False,
72
- alpha: float = 1.,
73
- pow: float = 2,
74
- debiased: bool = True,
75
- ):
76
- defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
77
- super().__init__(defaults)
78
- self.getter = itemgetter('amsgrad','pow','debiased')
79
-
80
- @torch.no_grad
81
- def step(self, var):
82
- step = self.global_state['step'] = self.global_state.get('step', 0) + 1
83
-
84
- beta1,beta2,eps,alpha=self.get_settings(var.params, 'beta1','beta2','eps','alpha', cls=NumberList)
85
- amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
86
-
87
- if amsgrad:
88
- exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state(var.params,'exp_avg','exp_avg_sq','max_exp_avg_sq', cls=TensorList)
89
- else:
90
- exp_avg, exp_avg_sq = self.get_state(var.params, 'exp_avg','exp_avg_sq', cls=TensorList)
91
- max_exp_avg_sq = None
92
-
93
- # if this is last module, update parameters in-place with slightly more efficient addcdiv_
94
- if var.is_last:
95
- if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
96
- passed_params = TensorList(var.params)
97
- var.stop = True
98
- var.skip_update = True
99
-
100
- else:
101
- passed_params = None
102
-
103
- p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
104
- g_prev = self.get_state(var.params, 'g_prev', cls=TensorList)
105
-
106
-
107
- var.update = adamy_(
108
- p=TensorList(var.params),
109
- p_prev=p_prev,
110
- g=TensorList(var.get_update()),
111
- g_prev=g_prev,
112
- exp_avg_=exp_avg,
113
- exp_avg_sq_=exp_avg_sq,
114
- alpha=alpha,
115
- beta1=beta1,
116
- beta2=beta2,
117
- eps=eps,
118
- step=step,
119
- pow=pow,
120
- debiased=debiased,
121
- max_exp_avg_sq_=max_exp_avg_sq,
122
- params_=passed_params,
123
- )
124
-
125
- return var
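
AdamY, also removed, swaps the input of the second moment: instead of the gradient it uses the gradient difference y = g - g_prev divided by the norm of the parameter step s. A minimal single-tensor sketch (illustrative names, not torchzero's API):

    import torch

    p_prev, g_prev = torch.randn(10), torch.randn(10)
    p, g = p_prev - 0.01 * g_prev, torch.randn(10)     # pretend one step was already taken
    exp_avg, exp_avg_sq = torch.zeros(10), torch.zeros(10)
    beta1, beta2, eps = 0.9, 0.999, 1e-8

    s = p - p_prev
    y = (g - g_prev) / torch.linalg.vector_norm(s).clip(min=1e-8)
    exp_avg.lerp_(g, 1 - beta1)                                    # first momentum still uses g
    exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1 - beta2)         # second moment driven by y
    update = exp_avg / (exp_avg_sq.sqrt() + eps)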
torchzero/modules/experimental/adasoap.py (deleted)
@@ -1,172 +0,0 @@
- from operator import itemgetter
-
- import torch
-
- from ...core import Chainable, Transform
- from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
- from ..optimizers.soap import (
-     get_orthogonal_matrix,
-     get_orthogonal_matrix_QR,
-     project,
-     project_back,
- )
-
-
- @torch.no_grad
- def update_adasoap_covariances_(
-     grad: torch.Tensor,
-     GGs_: list[torch.Tensor | None],
-     GG_sqs: list[torch.Tensor | None],
-     beta: float | None,
-     precond_beta: float | None,
- ):
-     for i, (GG, GG_sq) in enumerate(zip(GGs_, GG_sqs)):
-         if GG is None: continue
-         assert GG_sq is not None
-
-         if precond_beta is None: GG_sq.addcmul_(GG, GG)
-         else: GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1-precond_beta)
-
-         axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
-         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
-         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-
- class AdaSOAP(Transform):
-     """SOAP with diagonally preconditioned GG^Ts. Please note that this is experimental and isn't guaranteed to work.
-
-     precond_beta - beta for GG^T squares
-     """
-     def __init__(
-         self,
-         beta1: float = 0.95,
-         beta2: float = 0.95,
-         shampoo_beta: float | None = 0.95,
-         precond_beta: float | None = 0.95,
-         precond_freq: int = 10,
-         merge_small: bool = True,
-         max_dim: int = 2_000,
-         precondition_1d: bool = True,
-         eps: float = 1e-8,
-         decay: float | None = None,
-         alpha: float = 1,
-         unprojected_exp_avg: bool = True,
-         bias_correction: bool = True,
-     ):
-         defaults = dict(
-             beta1=beta1,
-             beta2=beta2,
-             shampoo_beta=shampoo_beta,
-             precond_beta=precond_beta,
-             precond_freq=precond_freq,
-             merge_small=merge_small,
-             max_dim=max_dim,
-             precondition_1d=precondition_1d,
-             eps=eps,
-             decay=decay,
-             unprojected_exp_avg=unprojected_exp_avg,
-             bias_correction=bias_correction,
-             alpha=alpha,
-         )
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         updates = []
-         # update preconditioners
-         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-
-             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(setting)
-             precond_beta = setting['precond_beta']
-
-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-             # initialize state on 1st step
-             if 'GG' not in state:
-                 state["exp_avg"] = torch.zeros_like(t)
-                 state["exp_avg_sq"] = torch.zeros_like(t)
-
-                 if not precondition_1d and t.ndim <= 1:
-                     state['GG'] = []
-                     state['GG_sq'] = []
-
-                 else:
-                     state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-                     state['GG_sq'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-
-                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                 if len([i is not None for i in state['GG']]) == 0:
-                     state['GG'] = None
-                     state['GG_sq'] = None
-
-                 if state['GG'] is not None:
-                     assert state['GG_sq'] is not None
-                     update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
-                     GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                     state['Q'] = get_orthogonal_matrix(GG_precond)
-
-                 state['step'] = 0
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # that can mess with other modules scaling
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             t_projected = None
-             if state['GG'] is not None:
-                 t_projected = project(t, state['Q'])
-
-             # exponential moving averages
-             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-             exp_avg: torch.Tensor = state["exp_avg"]
-             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-             if unprojected_exp_avg or t_projected is None:
-                 exp_avg.lerp_(t, 1-beta1)
-             else:
-                 exp_avg.lerp_(t_projected, 1-beta1)
-
-             if t_projected is None:
-                 exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-             else:
-                 exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
-
-             # project exponential moving averages if they are accumulated unprojected
-             exp_avg_projected = exp_avg
-             if unprojected_exp_avg and t_projected is not None:
-                 exp_avg_projected = project(exp_avg, state['Q'])
-
-             exp_avg_sq_projected = exp_avg_sq
-
-             denom = exp_avg_sq_projected.sqrt().add_(eps)
-             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             update = exp_avg_projected / denom
-             if t_projected is not None:
-                 update = project_back(update, state["Q"])
-
-             if setting['bias_correction']:
-                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-             elif alpha is not None:
-                 update *= alpha
-
-             if merge_small:
-                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-             updates.append(update)
-             state["step"] += 1
-
-             # Update is done after the gradient step to avoid using current gradients in the projection.
-             if state['GG'] is not None:
-                 update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
-                 GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                 if state['step'] % setting['precond_freq'] == 0:
-                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])
-
-         return updates
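
AdaSOAP's one departure from SOAP is that each accumulated GGᵀ is divided elementwise by an EMA of its square before the eigenbasis is computed (the GG_precond list fed to get_orthogonal_matrix above). A minimal sketch of that preconditioning step on a single matrix, with illustrative names:

    import torch

    A = torch.randn(16, 16)
    GG = A @ A.T                       # accumulated covariance, symmetric
    GG_sq = GG * GG                    # stands in for the EMA of its elementwise square
    GG_precond = GG / (GG_sq + 1e-8)   # still symmetric, so an eigendecomposition applies
    Q = torch.linalg.eigh(GG_precond).eigenvectors   # eigenbasis used to project gradients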