torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
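Most of the churn in this release is a reorganization: the files under torchzero/modules/optimizers/ move to torchzero/modules/adaptive/, several new module groups appear (conjugate_gradient, least_squares, restarts, termination, trust_region, variance_reduction), and a number of experimental modules are deleted outright. The four removal hunks reproduced below are the deleted absoap.py, adadam.py, adamY.py and adam_lambertw.py (items 136-139 above). A minimal sketch of what the optimizers → adaptive rename implies for imports, assuming the old subpackage is not kept around as a re-export shim (this diff does not show either way, and the names are taken only from the renamed paths listed above):

    # Hypothetical before/after, inferred from the renamed paths in the file list.
    # 0.3.11 located these files under torchzero/modules/optimizers/:
    #     from torchzero.modules.optimizers import soap, shampoo, muon
    # 0.3.14 moves the same files under torchzero/modules/adaptive/:
    from torchzero.modules.adaptive import soap, shampoo, muon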
@@ -1,253 +0,0 @@
- from operator import itemgetter
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Transform
- from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
- from ..optimizers.soap import project, project_back, get_orthogonal_matrix, get_orthogonal_matrix_QR
-
- @torch.no_grad
- def update_absoap_covariances_(
-     g1: torch.Tensor,
-     g2: torch.Tensor,
-     GGs_: list[torch.Tensor | None],
-     beta: float | None,
- ):
-     for i, GG in enumerate(GGs_):
-         if GG is None: continue
-
-         axes = list(range(i)) + list(range(i + 1, g1.ndim)) # this works fine with 1d params
-         if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
-         else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-
- Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
- class ABSOAP(Transform):
-     """SOAP but with some extra options for testing.
-
-     .. warning::
-         This module is just for testing my stupid ideas.
-
-     Args:
-         scale_by_s - whether to scale y by s
-         gg1 - 1st vector into GGᵀ
-         gg2 - 2nd vector into GGᵀ
-         ema1 - vector into 1st momentum
-         ema2 - 2 vectors into 2nd momentum
-         rel1 - if True, multiplies gg1 by params
-         rel2 - same but for gg2
-         norm - if True, gg1 a and gg2 are normalized, and I need to make that into a letter
-
-     letters:
-         p - params
-         g - grad
-         s - param difference
-         y - grad difference
-         gy - g+y
-         sy - s+y
-         sn - s normalized
-         yn - y normalized
-         gys - g + y#g
-         sys - s + y#s
-
-     """
-     def __init__(
-         self,
-         beta1: float = 0.95,
-         beta2: float = 0.95,
-         shampoo_beta: float | None = 0.95,
-         precond_freq: int = 10,
-         merge_small: bool = True,
-         max_dim: int = 2_000,
-         precondition_1d: bool = True,
-         eps: float = 1e-8,
-         decay: float | None = None,
-         alpha: float = 1,
-         bias_correction: bool = True,
-         scale_by_s: bool = True,
-         gg1: Source='g',
-         gg2: Source='g',
-         ema1: Source='g',
-         ema2: tuple[Source, Source] = ('g','g'),
-         rel1: bool=False,
-         rel2: bool=False,
-         norm: bool = False,
-     ):
-         defaults = dict(
-             beta1=beta1,
-             beta2=beta2,
-             shampoo_beta=shampoo_beta,
-             precond_freq=precond_freq,
-             merge_small=merge_small,
-             max_dim=max_dim,
-             precondition_1d=precondition_1d,
-             eps=eps,
-             decay=decay,
-             bias_correction=bias_correction,
-             alpha=alpha,
-             scale_by_s=scale_by_s,
-             ema1=ema1,
-             ema2=ema2,
-             first=gg1,
-             second=gg2,
-             rel1=rel1, rel2=rel2,
-             norm=norm,
-         )
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         updates = []
-         # update preconditioners
-         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
-             scale_by_s = setting['scale_by_s']
-             ema1 = setting['ema1']
-             ema2 = setting['ema2']
-             first=setting['first']
-             second=setting['second']
-             rel1 = setting['rel1']; rel2 = setting['rel2']
-             norm=setting['norm']
-
-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-             if 'g_prev' not in state:
-                 state['p_prev'] = p.clone()
-                 state['g_prev'] = t.clone()
-                 # updates.append(tensors[i].clip(-0.1,0.1))
-                 # continue
-
-             p_prev = state['p_prev']
-             g_prev = state['g_prev']
-             s = p - p_prev
-             y = t - g_prev
-
-             # keep malding
-             p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
-             g_norm = torch.linalg.vector_norm(t) # pylint:disable=not-callable
-             s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
-             y_norm = torch.linalg.vector_norm(y) # pylint:disable=not-callable
-
-             sn = p - p_prev * (p_norm / torch.linalg.vector_norm(p_prev))# pylint:disable=not-callable
-             yn = t - g_prev * (g_norm / torch.linalg.vector_norm(g_prev))# pylint:disable=not-callable
-
-             if scale_by_s: y /= s_norm.clip(min=1e-8) # pylint:disable=not-callable
-
-             state['p_prev'].copy_(p)
-             state['g_prev'].copy_(t)
-
-             def _get(c: Source):
-                 if c == 'p': return p
-                 if c == 'g': return t
-                 if c == 's': return s
-                 if c == 'y': return y
-                 if c == 'sn': return sn
-                 if c == 'yn': return yn
-                 if c == 'gy': return t+y
-                 if c == 'sy': return s+y
-                 if c == 'gys':
-                     y_scaled = y * (g_norm/y_norm.clip(min=1e-8))
-                     return t+y_scaled
-                 if c == 'sys':
-                     y_scaled = y * (s_norm/y_norm.clip(min=1e-8))
-                     return s+y_scaled
-                 raise RuntimeError("Big Chungus")
-
-             t1 = _get(first)
-             if rel1: t1 = t1 * p.abs().clip(min=1e-6)
-             t2 = _get(second)
-             if rel2: t2 = t2 * p.abs().clip(min=1e-6)
-
-             t_ema1 = _get(ema1)
-             t_ema2s = _get(ema2[0]), _get(ema2[1])
-
-             if norm:
-                 t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
-                 t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable
-
-             # initialize state on 1st step
-             if 'GG' not in state:
-                 state["exp_avg"] = torch.zeros_like(t)
-                 state["exp_avg_sq"] = torch.zeros_like(t)
-
-                 if not precondition_1d and t.ndim <= 1:
-                     state['GG'] = []
-
-                 else:
-                     state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                 if len([i is not None for i in state['GG']]) == 0:
-                     state['GG'] = None
-
-                 if state['GG'] is not None:
-                     update_absoap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
-                     state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                 state['step'] = 0
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             z1_projected = None
-             z2_projected = None
-
-             if state['GG'] is not None:
-                 z1_projected = project(t_ema2s[0], state['Q'])
-                 if ema2[0] == ema2[1]: z2_projected = z1_projected
-                 else: z2_projected = project(t_ema2s[1], state['Q'])
-
-             # exponential moving averages
-             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-             exp_avg: torch.Tensor = state["exp_avg"]
-             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-             exp_avg.lerp_(t_ema1, 1-beta1)
-
-             if z1_projected is None:
-                 exp_avg_sq.mul_(beta2).addcmul_(*t_ema2s, value=1-beta2)
-             else:
-                 assert z2_projected is not None
-                 exp_avg_sq.mul_(beta2).addcmul_(z1_projected, z2_projected, value=1-beta2)
-
-             # project exponential moving averages if they are accumulated unprojected
-             exp_avg_projected = exp_avg
-             if z1_projected is not None:
-                 exp_avg_projected = project(exp_avg, state['Q'])
-
-             exp_avg_sq_projected = exp_avg_sq
-
-             denom = exp_avg_sq_projected.sqrt().add_(eps)
-             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             update = exp_avg_projected / denom
-             if z1_projected is not None:
-                 update = project_back(update, state["Q"])
-
-             if setting['bias_correction']:
-                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-             elif alpha is not None:
-                 update *= alpha
-
-             if merge_small:
-                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-             updates.append(update)
-             state["step"] += 1
-
-             # Update is done after the gradient step to avoid using current gradients in the projection.
-             if state['GG'] is not None:
-                 update_absoap_covariances_(t1, t2, state['GG'], shampoo_beta)
-                 if state['step'] % setting['precond_freq'] == 0:
-                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-         return updates
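The 253-line hunk above deletes torchzero/modules/experimental/absoap.py (item 136), an experimental SOAP variant. Its core is the Shampoo/SOAP-style statistic update in update_absoap_covariances_: for each dimension i of a gradient tensor, a small square matrix is accumulated by contracting two vectors over every other dimension. A minimal, self-contained sketch of that accumulation, mirroring the removed helper; the function name and toy shapes below are illustrative, not part of the torchzero API:

    import torch

    def update_covariances_(g1, g2, GGs, beta=None):
        # For each tensor dimension i, accumulate an (s_i x s_i) statistic by
        # contracting g1 and g2 over every axis except i. None marks skipped dims.
        for i, GG in enumerate(GGs):
            if GG is None:
                continue
            axes = list(range(i)) + list(range(i + 1, g1.ndim))
            outer = torch.tensordot(g1, g2, dims=(axes, axes))
            if beta is None:
                GG.add_(outer)               # running sum of statistics
            else:
                GG.lerp_(outer, 1 - beta)    # exponential moving average

    # toy usage: a 3x4 "gradient" yields a 3x3 and a 4x4 accumulator
    g = torch.randn(3, 4)
    GGs = [torch.zeros(3, 3), torch.zeros(4, 4)]
    update_covariances_(g, g, GGs, beta=0.95)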
@@ -1,118 +0,0 @@
- from operator import itemgetter
- from functools import partial
-
- import torch
-
- from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList
- from ..functional import (
-     debias, debiased_step_size,
-     ema_,
-     sqrt_ema_sq_,
- )
- from ..step_size.lr import lazy_lr
- from ..momentum.experimental import sqrt_nag_ema_sq_
- from ..momentum.momentum import nag_
-
-
- def adadam_(
-     tensors: TensorList,
-     exp_avg_: TensorList,
-     exp_avg_sq_: TensorList,
-     exp_avg_qu_: TensorList,
-     alpha: float | NumberList,
-     beta1: float | NumberList,
-     beta2: float | NumberList,
-     precond_beta: float | NumberList,
-     eps: float | NumberList,
-     step: int,
-     pow: float = 2,
-     debiased: bool = True,
-     max_exp_avg_sq_: TensorList | None = None,
-     max_exp_avg_qu_: TensorList | None = None,
-     params_: TensorList | None = None,
- ):
-     """Returns new tensors or updates params in-place."""
-     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
-     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-                                    debiased=False,step=step,pow=pow)
-     sqrt_exp_avg_qu = sqrt_ema_sq_(tensors/(sqrt_exp_avg_sq+1e-8), exp_avg_sq_=exp_avg_qu_,
-                                    beta=precond_beta,max_exp_avg_sq_=max_exp_avg_qu_, debiased=False,step=step,pow=pow)
-
-     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-     # params is None, return update
-     if params_ is None: return (exp_avg_ / sqrt_exp_avg_qu.add_(eps)).lazy_mul(alpha)
-
-     # update params in-place
-     params_.addcdiv_(exp_avg_, sqrt_exp_avg_qu.add_(eps), -alpha)
-     return None
-
- class Adadam(Module):
-     """Adam with a diagonally preconditioned preconditioner.
-
-     Verdict: I haven't tested this yet.
-
-     .. warning::
-         Experimental.
-     """
-     def __init__(
-         self,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         precond_beta: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad: bool = False,
-         alpha: float = 1.,
-         pow: float = 2,
-         debiased: bool = True,
-     ):
-         defaults=dict(beta1=beta1,beta2=beta2,precond_beta=precond_beta,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-         super().__init__(defaults)
-         self.getter = itemgetter('amsgrad','pow','debiased')
-
-     @torch.no_grad
-     def step(self, var):
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-         params = var.params
-
-         beta1,beta2,precond_beta,eps,alpha=self.get_settings(params, 'beta1','beta2','precond_beta','eps','alpha', cls=NumberList)
-         amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
-
-         if amsgrad:
-             exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', cls=TensorList)
-         else:
-             exp_avg, exp_avg_sq, exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', cls=TensorList)
-             max_exp_avg_sq = None
-             max_exp_avg_qu = None
-
-         # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-         if var.is_last:
-             if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
-             passed_params = TensorList(var.params)
-             var.stop = True
-             var.skip_update = True
-
-         else:
-             passed_params = None
-
-         var.update = adadam_(
-             tensors=TensorList(var.get_update()),
-             exp_avg_=exp_avg,
-             exp_avg_sq_=exp_avg_sq,
-             exp_avg_qu_=exp_avg_qu,
-             alpha=alpha,
-             beta1=beta1,
-             beta2=beta2,
-             precond_beta=precond_beta,
-             eps=eps,
-             step=step,
-             pow=pow,
-             debiased=debiased,
-             max_exp_avg_sq_=max_exp_avg_sq,
-             max_exp_avg_qu_=max_exp_avg_qu,
-             params_=passed_params,
-         )
-
-         return var
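The 118-line hunk above deletes torchzero/modules/experimental/adadam.py (item 137). The removed Adadam module is Adam whose denominator is an EMA of already Adam-preconditioned gradients rather than of the raw gradients ("a diagonally preconditioned preconditioner"). A single-tensor sketch of that update with bias correction and the amsgrad branch omitted; this is illustrative only, the real module operated on TensorLists through torchzero's ema_/sqrt_ema_sq_ helpers:

    import torch

    def adadam_step(g, exp_avg, exp_avg_sq, exp_avg_qu,
                    beta1=0.9, beta2=0.999, precond_beta=0.999, eps=1e-8):
        exp_avg.lerp_(g, 1 - beta1)                              # first moment
        exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # usual Adam second moment
        g_pre = g / exp_avg_sq.sqrt().add(eps)                   # Adam-preconditioned gradient
        # second moment of the preconditioned gradient: the "preconditioned preconditioner"
        exp_avg_qu.mul_(precond_beta).addcmul_(g_pre, g_pre, value=1 - precond_beta)
        return exp_avg / exp_avg_qu.sqrt().add(eps)              # update direction

    g = torch.randn(5)
    exp_avg, exp_avg_sq, exp_avg_qu = (torch.zeros(5) for _ in range(3))
    direction = adadam_step(g, exp_avg, exp_avg_sq, exp_avg_qu)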
@@ -1,131 +0,0 @@
- from operator import itemgetter
- from functools import partial
-
- import torch
-
- from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList
- from ..functional import (
-     debias, debiased_step_size,
-     ema_,
-     sqrt_ema_sq_,
- )
- from ..step_size.lr import lazy_lr
- from ..momentum.experimental import sqrt_nag_ema_sq_
- from ..momentum.momentum import nag_
-
-
- def adamy_(
-     p: TensorList,
-     p_prev: TensorList,
-     g: TensorList,
-     g_prev: TensorList,
-     exp_avg_: TensorList,
-     exp_avg_sq_: TensorList,
-     alpha: float | NumberList,
-     beta1: float | NumberList,
-     beta2: float | NumberList,
-     eps: float | NumberList,
-     step: int,
-     pow: float = 2,
-     debiased: bool = True,
-     max_exp_avg_sq_: TensorList | None = None,
-     params_: TensorList | None = None,
- ):
-     """Returns new tensors or updates params in-place."""
-     if step == 1:
-         p_prev.copy_(p)
-         g_prev.copy_(g)
-
-         update = g.clip(-0.1,0.1).lazy_mul_(alpha)
-         if params_ is None: return update
-         params_.sub_(update)
-         return None
-
-     s = p-p_prev
-     y = (g-g_prev).div_(s.global_vector_norm().clip(min=1e-8))
-     p_prev.copy_(p)
-     g_prev.copy_(g)
-
-     exp_avg_ = ema_(g, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
-     sqrt_exp_avg_sq = sqrt_ema_sq_(y, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-                                    debiased=False,step=step,pow=pow)
-
-     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-     # params is None, return update
-     if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
-
-     # update params in-place
-     params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
-     return None
-
- class AdamY(Module):
-     """Adam but uses scaled gradient differences for second momentum.
-
-     Verdict: I haven't tested this yet.
-
-     .. warning::
-         Experimental.
-     """
-     def __init__(
-         self,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad: bool = False,
-         alpha: float = 1.,
-         pow: float = 2,
-         debiased: bool = True,
-     ):
-         defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-         super().__init__(defaults)
-         self.getter = itemgetter('amsgrad','pow','debiased')
-
-     @torch.no_grad
-     def step(self, var):
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         beta1,beta2,eps,alpha=self.get_settings(var.params, 'beta1','beta2','eps','alpha', cls=NumberList)
-         amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
-
-         if amsgrad:
-             exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state(var.params,'exp_avg','exp_avg_sq','max_exp_avg_sq', cls=TensorList)
-         else:
-             exp_avg, exp_avg_sq = self.get_state(var.params, 'exp_avg','exp_avg_sq', cls=TensorList)
-             max_exp_avg_sq = None
-
-         # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-         if var.is_last:
-             if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
-             passed_params = TensorList(var.params)
-             var.stop = True
-             var.skip_update = True
-
-         else:
-             passed_params = None
-
-         p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
-         g_prev = self.get_state(var.params, 'g_prev', cls=TensorList)
-
-
-         var.update = adamy_(
-             p=TensorList(var.params),
-             p_prev=p_prev,
-             g=TensorList(var.get_update()),
-             g_prev=g_prev,
-             exp_avg_=exp_avg,
-             exp_avg_sq_=exp_avg_sq,
-             alpha=alpha,
-             beta1=beta1,
-             beta2=beta2,
-             eps=eps,
-             step=step,
-             pow=pow,
-             debiased=debiased,
-             max_exp_avg_sq_=max_exp_avg_sq,
-             params_=passed_params,
-         )
-
-         return var
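The 131-line hunk above deletes torchzero/modules/experimental/adamY.py (item 138). AdamY keeps Adam's first moment of the gradient but builds the second moment from scaled gradient differences y = (g - g_prev) / ||p - p_prev||, taking a conservative clipped step while no history exists. A single-tensor sketch under the same simplifications (no bias correction, no amsgrad); the function name and state-dict layout are illustrative:

    import torch

    def adamy_step(p, g, state, beta1=0.9, beta2=0.999, eps=1e-8):
        if "p_prev" not in state:                    # first call: store history, take a safe step
            state.update(p_prev=p.clone(), g_prev=g.clone(),
                         exp_avg=torch.zeros_like(p), exp_avg_sq=torch.zeros_like(p))
            return g.clamp(-0.1, 0.1)
        s = p - state["p_prev"]                      # parameter difference
        y = (g - state["g_prev"]) / torch.linalg.vector_norm(s).clamp(min=1e-8)
        state["p_prev"].copy_(p)
        state["g_prev"].copy_(g)
        state["exp_avg"].lerp_(g, 1 - beta1)         # first moment still tracks the gradient
        state["exp_avg_sq"].mul_(beta2).addcmul_(y, y, value=1 - beta2)  # second moment tracks y
        return state["exp_avg"] / state["exp_avg_sq"].sqrt().add(eps)

    p, g, st = torch.randn(4), torch.randn(4), {}
    adamy_step(p, g, st)                             # first, clipped step
    adamy_step(p - 0.01 * g, torch.randn(4), st)     # later steps use the differences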
@@ -1,149 +0,0 @@
- from operator import itemgetter
- from functools import partial
- import math
- import torch
-
- from ...core import Module, Target, Transform, apply_transform, Chainable
- from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..functional import (
-     debias, debiased_step_size,
-     ema_,
-     sqrt_ema_sq_,
- )
- from ..step_size.lr import lazy_lr
- from ..momentum.experimental import sqrt_nag_ema_sq_
- from ..momentum.momentum import nag_
-
-
- def _lambertw_newton_raphson(x: TensorList, iterations=5):
-     # z = torch.zeros_like(x)
-     # mask_neg = x < 0
-     # mask_pos = ~mask_neg
-
-     # z[mask_pos] = torch.log(x[mask_pos] + 1.0)
-
-     # x_neg = x[mask_neg]
-     # z_neg = -1.0 + torch.sqrt(2.0 * (1.0 + math.e * x_neg))
-     # z[mask_neg] = z_neg
-
-     # x is always positive
-     z = (x+1).log_()
-     for _ in range(iterations):
-         exp_z = z.exp()
-         numerator = z * exp_z - x
-         denominator = exp_z * (z + 1.0) + 1e-8
-         delta = numerator / denominator
-         z -= delta
-     return z
-
- # https://github.com/gmgeorg/torchlambertw/blob/main/torchlambertw/special.py
- def _lambertw_winitzki(x: TensorList):
-     x_log1p = x.log1p()
-     return x_log1p * (1.0 - x_log1p.log1p() / (2.0 + x_log1p))
-
-
- def adam_lambertw_(
-     tensors: TensorList,
-     exp_avg_: TensorList,
-     exp_avg_xpx_: TensorList,
-     alpha: float | NumberList,
-     beta1: float | NumberList,
-     beta2: float | NumberList,
-     eps: float | NumberList,
-     step: int,
-     pow: float = 2,
-     debiased: bool = True,
-     max_exp_avg_xpx_: TensorList | None = None,
-     iterations: int | None = 5,
-
-     # inner args
-     inner: Module | None = None,
-     params: list[torch.Tensor] | None = None,
-     grads: list[torch.Tensor] | None = None,
- ):
-     """Returns new tensors."""
-     tensors_abs = tensors.abs().clip_(max=20)
-     tensors_xpx = tensors_abs.pow_(tensors_abs)
-     exp_avg_xpx_.lerp_(tensors_xpx, 1-beta2)
-
-     if max_exp_avg_xpx_ is not None:
-         max_exp_avg_xpx_.maximum_(exp_avg_xpx_)
-         exp_avg_xpx_ = max_exp_avg_xpx_
-
-     if inner is not None:
-         assert params is not None
-         tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
-
-     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-     if iterations is None or iterations < 1: exp_avg_xpx_ = _lambertw_winitzki(exp_avg_xpx_)
-     else: exp_avg_xpx_ = _lambertw_newton_raphson(exp_avg_xpx_, iterations)
-
-     return (exp_avg_.lazy_mul(alpha) / exp_avg_xpx_.add_(eps))
-
- class AdamLambertW(Transform):
-     """Adam but uses abs x^x and LambertW instead of square and sqrt.
-     The gradient will be clipped to 20 because float32 which you have to use otherwise you're PC will explode.
-
-     Args:
-         beta1 (float, optional): momentum. Defaults to 0.9.
-         beta2 (float, optional): second momentum. Defaults to 0.999.
-         eps (float, optional): epsilon. Defaults to 1e-8.
-         alpha (float, optional): learning rate. Defaults to 1.
-         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
-         pow (float, optional): power used in second momentum power and root. Defaults to 2.
-         debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
-         iterations (int, optional): 0 or None means Winitzki approximation otherwise number of newton raphson iterations.
-     """
-     def __init__(
-         self,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad: bool = False,
-         alpha: float = 1.,
-         pow: float = 2,
-         debiased: bool = True,
-         iterations: int | None = 5,
-         inner: Chainable | None = None
-     ):
-         defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased, iterations=iterations)
-         super().__init__(defaults, uses_grad=False)
-
-         if inner is not None: self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
-         amsgrad,pow,debiased,iterations = itemgetter('amsgrad','pow','debiased','iterations')(settings[0])
-
-         if amsgrad:
-             exp_avg, exp_avg_xpx, max_exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', 'max_exp_avg_xpx', cls=TensorList)
-         else:
-             exp_avg, exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', cls=TensorList)
-             max_exp_avg_xpx = None
-
-
-         return adam_lambertw_(
-             tensors=TensorList(tensors),
-             exp_avg_=exp_avg,
-             exp_avg_xpx_=exp_avg_xpx,
-             alpha=alpha,
-             beta1=beta1,
-             beta2=beta2,
-             eps=eps,
-             step=step,
-             pow=pow,
-             debiased=debiased,
-             max_exp_avg_xpx_=max_exp_avg_xpx,
-             iterations=iterations,
-
-             # inner args
-             inner=self.children.get("inner", None),
-             params=params,
-             grads=grads,
-
-         )
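The final hunk deletes torchzero/modules/experimental/adam_lambertw.py (item 139). AdamLambertW replaces Adam's g² accumulator with |g|^|g| (clipped to 20 so the accumulator stays finite in float32) and applies the Lambert W function where Adam takes a square root: W inverts w·e^w just as the square root inverts squaring. A standalone sketch of the Newton-Raphson Lambert W solver the removed file uses, written for plain tensors rather than torchzero TensorLists (illustrative, not the torchzero API):

    import torch

    def lambertw_newton(x: torch.Tensor, iterations: int = 5) -> torch.Tensor:
        # Solve w * exp(w) = x for w, valid for x >= 0 where W is single-valued.
        z = torch.log1p(x)                           # initial guess, exact at x = 0
        for _ in range(iterations):
            exp_z = z.exp()
            z = z - (z * exp_z - x) / (exp_z * (z + 1.0) + 1e-8)
        return z

    # sanity check: W(1) is the omega constant ~0.5671, and w * e^w recovers x
    x = torch.tensor([0.5, 1.0, 5.0])
    w = lambertw_newton(x)
    assert torch.allclose(w * w.exp(), x, atol=1e-4)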