torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/adaptive/adagrad.py (new file)
@@ -0,0 +1,356 @@
+from operator import itemgetter
+from typing import Literal
+
+import torch
+from ...core import (
+    Chainable,
+    Module,
+    Target,
+    TensorwiseTransform,
+    Transform,
+    Var,
+    apply_transform,
+)
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ...utils.linalg import matrix_power_eigh
+from ..functional import add_power_, lerp_power_, root, epsilon_step_size
+from ...utils.linalg.linear_operator import Dense
+
+def adagrad_(
+    tensors_: TensorList,
+    sq_sum_: TensorList,
+    alpha: float | NumberList,
+    lr_decay: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    use_sqrt: bool = True,
+    divide: bool = False,
+
+    decay: float | None = None,
+    beta: float | None = None,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """returns ``tensors_``"""
+    clr = alpha / (1 + step * lr_decay)
+
+    if beta is None or step == 1: sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
+    else: sq_sum_ = lerp_power_(tensors_, exp_avg_pow_=sq_sum_, beta=beta, pow=pow)
+    if decay is not None:
+        sq_sum_.mul_(1-decay)
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+    if divide: sq_sum_ = sq_sum_ / max(step, 1)
+
+    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
+    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
+
+    return tensors_
+
+
+
+class Adagrad(Transform):
+    """Adagrad, divides by sum of past squares of gradients.
+
+    This implementation is identical to ``torch.optim.Adagrad``.
+
+    Args:
+        lr_decay (float, optional): learning rate decay. Defaults to 0.
+        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+        eps (float, optional): division epsilon. Defaults to 1e-10.
+        alpha (float, optional): step size. Defaults to 1.
+        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+        inner (Chainable | None, optional): Inner modules that are applied after updating the accumulator and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        lr_decay: float = 0,
+        initial_accumulator_value: float = 0,
+        eps: float = 1e-10,
+        alpha: float = 1,
+        pow: float = 2,
+        use_sqrt: bool = True,
+        divide: bool = False,
+        beta: float | None = None,
+        decay: float | None = None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha=alpha, lr_decay=lr_decay, initial_accumulator_value=initial_accumulator_value,
+                        eps=eps, pow=pow, use_sqrt=use_sqrt, divide=divide, beta=beta, decay=decay)
+        super().__init__(defaults=defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        lr_decay, alpha, eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+
+        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
+
+        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
+
+        # initialize accumulator on 1st step
+        if step == 1:
+            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
+
+        return adagrad_(
+            tensors,
+            sq_sum_=sq_sum,
+            alpha=alpha,
+            lr_decay=lr_decay,
+            eps=eps,
+            step=step,
+            pow=pow,
+            use_sqrt=use_sqrt,
+            divide=divide,
+
+            beta=self.defaults["beta"],
+            decay=self.defaults["decay"],
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+        )
+
+
+def lerp(start, end, weight):
+    return start + weight * (end - start)
+
+def adagrad_norm_(
+    tensors_: TensorList,
+    accumulator: float | torch.Tensor,
+    alpha: float | NumberList,
+    lr_decay: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    use_sqrt: bool = True,
+    divide: bool = False,
+
+    decay: float | None = None,
+    beta: float | None = None,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """returns ``tensors_``"""
+    clr = alpha / (1 + step * lr_decay)
+
+    gg = tensors_.dot(tensors_)
+
+    if beta is None or step == 1: accumulator += gg
+    else: accumulator = lerp(accumulator, gg, 1-beta)
+
+    if decay is not None:
+        accumulator *= 1-decay
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+    if divide: accumulator = accumulator / max(step, 1)
+
+    if use_sqrt: tensors_.div_(eps + accumulator.sqrt()).mul_(clr)
+    else: tensors_.div_(eps + accumulator).mul_(clr)
+
+    return tensors_, accumulator
+
+class AdagradNorm(Transform):
+    """Adagrad-Norm, divides by sum of past means of squares of gradients.
+
+    Args:
+        lr_decay (float, optional): learning rate decay. Defaults to 0.
+        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+        eps (float, optional): division epsilon. Defaults to 1e-10.
+        alpha (float, optional): step size. Defaults to 1.
+        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+        inner (Chainable | None, optional): Inner modules that are applied after updating the accumulator and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        lr_decay: float = 0,
+        initial_accumulator_value: float = 0,
+        eps: float = 1e-10,
+        alpha: float = 1,
+        pow: float = 2,
+        use_sqrt: bool = True,
+        divide: bool = False,
+        beta: float | None = None,
+        decay: float | None = None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha=alpha, lr_decay=lr_decay, initial_accumulator_value=initial_accumulator_value,
+                        eps=eps, pow=pow, use_sqrt=use_sqrt, divide=divide, beta=beta, decay=decay)
+        super().__init__(defaults=defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        lr_decay, alpha, eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+
+        use_sqrt, divide, initial_accumulator_value = itemgetter('use_sqrt', 'divide', "initial_accumulator_value")(settings[0])
+
+        accumulator = self.global_state.get("accumulator", initial_accumulator_value)
+
+        d, self.global_state["accumulator"] = adagrad_norm_(
+            tensors,
+            accumulator=accumulator,
+            alpha=alpha,
+            lr_decay=lr_decay,
+            eps=eps,
+            step=step,
+            use_sqrt=use_sqrt,
+            divide=divide,
+
+            beta=self.defaults["beta"],
+            decay=self.defaults["decay"],
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+        )
+
+        return d
+
+
+class FullMatrixAdagrad(TensorwiseTransform):
+    """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).
+
+    Note:
+        A more memory-efficient version equivalent to full-matrix Adagrad on the last n gradients is implemented in ``tz.m.LMAdagrad``.
+
+    Args:
+        beta (float | None, optional): momentum for gradient outer product accumulators. If None, uses sum. Defaults to None.
+        decay (float | None, optional): decay for gradient outer product accumulators. Defaults to None.
+        sqrt (bool, optional): whether to take the square root of the accumulator. Defaults to True.
+        concat_params (bool, optional): if False, each parameter will have its own accumulator. Defaults to True.
+        precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
+        init (Literal[str], optional):
+            how to initialize the accumulator.
+            - "identity" - with identity matrix (default).
+            - "zeros" - with zero matrix.
+            - "ones" - with matrix of ones.
+            - "GGT" - with the first outer product.
+        divide (bool, optional): whether to divide the accumulator by the number of gradients in it. Defaults to False.
+        inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.
+
+    ## Examples:
+
+    Plain full-matrix Adagrad
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrad(),
+        tz.m.LR(1e-2),
+    )
+    ```
+
+    Full-matrix RMSprop
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrad(beta=0.99),
+        tz.m.LR(1e-2),
+    )
+    ```
+
+    Full-matrix Adam
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(1e-2),
+    )
+    ```
+    """
+    def __init__(
+        self,
+        beta: float | None = None,
+        decay: float | None = None,
+        sqrt: bool = True,
+        concat_params=True,
+        precond_freq: int = 1,
+        init: Literal["identity", "zeros", "ones", "GGT"] = "identity",
+        reg: float = 1e-12,
+        divide: bool = False,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, precond_freq=precond_freq, init=init, divide=divide, reg=reg)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+        G = tensor.ravel()
+        GG = torch.outer(G, G)
+        decay = setting['decay']
+        beta = setting['beta']
+        init = setting['init']
+
+        if 'GG' not in state:
+            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
+            elif init == 'ones': state['GG'] = torch.ones_like(GG)
+            elif init == 'GGT': state['GG'] = GG.clone()
+            else: raise ValueError(init)
+        if decay is not None: state['GG'].mul_(decay)
+
+        if beta is not None: state['GG'].lerp_(GG, 1-beta)
+        else: state['GG'].add_(GG)
+        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        step = state.get('step', 0)
+        state['step'] = step + 1
+
+        GG: torch.Tensor = state['GG']
+        sqrt = setting['sqrt']
+        divide = setting['divide']
+        precond_freq = setting['precond_freq']
+        reg = setting['reg']
+
+        if divide: GG = GG/state.get('i', 1)
+
+        if reg != 0:
+            GG = GG + torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype).mul_(reg)
+
+        if tensor.numel() == 1:
+            GG = GG.squeeze()
+            if sqrt: return tensor / GG.sqrt()
+            return tensor / GG
+
+        try:
+            if sqrt:
+                if "B" not in state or step % precond_freq == 0:
+                    B = state["B"] = matrix_power_eigh(GG, -1/2)
+                else:
+                    B = state["B"]
+
+            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+
+        except torch.linalg.LinAlgError:
+            # fallback to diagonal AdaGrad
+            denom = GG.diagonal()
+            if sqrt: denom = denom.sqrt()
+            return tensor.div_(denom + max(reg, 1e-12))
+
+        return (B @ tensor.ravel()).view_as(tensor)
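For orientation, with the default settings (``pow=2``, ``use_sqrt=True``, no ``beta``, ``decay`` or ``inner``) the ``adagrad_`` helper above reduces to the textbook diagonal Adagrad rule: accumulate squared gradients, then scale the gradient by the inverse root of the accumulator. A minimal stand-alone sketch in plain PyTorch; the function and variable names here are illustrative, not part of the torchzero API:

```python
import torch

def adagrad_step(grad, sq_sum, alpha=1.0, lr_decay=0.0, eps=1e-10, step=1):
    """Textbook diagonal Adagrad, mirroring adagrad_ above with
    pow=2, use_sqrt=True and no beta/decay/inner module."""
    sq_sum.add_(grad * grad)                    # running sum of squared gradients
    clr = alpha / (1 + step * lr_decay)         # decayed step size
    return grad / (sq_sum.sqrt() + eps) * clr   # preconditioned update

g = torch.tensor([0.5, -2.0, 1.0])
acc = torch.zeros_like(g)
print(adagrad_step(g, acc, step=1))   # tensor([ 1., -1.,  1.])
```

In a modular optimizer this transform would typically be chained with a learning-rate module, following the same pattern as the ``FullMatrixAdagrad`` examples above, e.g. ``tz.Modular(model.parameters(), tz.m.Adagrad(), tz.m.LR(1e-2))``.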
torchzero/modules/adaptive/adahessian.py (new file)
@@ -0,0 +1,224 @@
+import math
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist
+from ..functional import debiased_step_size
+
+def _full_average(hvp: torch.Tensor):
+    if hvp.ndim >= 3: # Conv kernel
+        return torch.mean(hvp.abs(), dim=[2, *range(3,hvp.ndim)], keepdim=True)
+    return hvp
+
+def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
+    """averages x over first dimension in blocks"""
+    if enable and x.ndim >= 2:
+        if math.prod(x.shape[1:]) <= 1: return x
+        if block_size is None: return _full_average(x)
+        size = x.size(0)
+
+        n_blocks = size // block_size
+        if n_blocks <= 1: return x.abs().mean(0, keepdim=True)
+
+        n_remaining = size - n_blocks * block_size
+        remaining = None
+        if n_remaining > 0:
+            remaining = x[-n_remaining:].abs().mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
+            x = x[:-n_remaining]
+
+        x = x.view(block_size, n_blocks, *x.shape[1:])
+        x_mean = x.abs().mean(0).repeat_interleave(block_size, 0)
+
+        if remaining is None: return x_mean
+        return torch.cat([x_mean, remaining], 0)
+
+    return x
+
+
+def _rademacher_like(tensor, p=0.5, generator=None):
+    """p is probability of a 1, other values will be -1."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator=generator).mul_(2).sub_(1)
+
+def adahessian(
+    tensors: TensorList,
+    D: TensorList | None,
+    exp_avg_: TensorList,
+    D_exp_avg_sq_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    update_freq: int,
+    eps: float | NumberList,
+    hessian_power: float | NumberList,
+    step: int,
+):
+    # momentum
+    exp_avg_.lerp_(tensors, 1-beta1)
+
+    # update preconditioner
+    if step % update_freq == 0:
+        assert D is not None
+        D_exp_avg_sq_.mul_(beta2).addcmul_(D, D, 1-beta2)
+
+    else:
+        assert D is None
+
+
+    denom = D_exp_avg_sq_.sqrt().pow_(hessian_power).add_(eps)
+    num = exp_avg_ * debiased_step_size(step+1, beta1, beta2)
+
+    return num.div_(denom)
+
+
+class AdaHessian(Module):
+    """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)
+
+    This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
+
+    Notes:
+        - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.
+
+        - If you are using gradient estimators or reformulations, set ``hvp_method`` to "forward" or "central".
+
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
+        averaging (bool, optional):
+            whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
+            This can be set per-parameter in param groups.
+        block_size (int, optional):
+            size of block in the block-diagonal averaging.
+        update_freq (int, optional):
+            frequency of updating the hessian diagonal estimate via a hessian-vector product.
+            This value can be increased to reduce computational cost. Defaults to 1.
+        eps (float, optional):
+            division stability epsilon. Defaults to 1e-8.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        fd_h (float, optional): finite difference step size if ``hvp_method`` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to a better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner module. If this is specified, operations are performed in the following order:
+            1. compute hessian diagonal estimate.
+            2. pass inputs to ``inner``.
+            3. momentum and preconditioning are applied to the outputs of ``inner``.
+
+    ## Examples:
+
+    Using AdaHessian:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.AdaHessian(),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    The AdaHessian preconditioner can be applied to any other module by passing it to the ``inner`` argument.
+    Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
+    AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```

+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        averaging: bool = True,
+        block_size: int | None = None,
+        update_freq: int = 1,
+        eps: float = 1e-8,
+        hessian_power: float = 1,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hessian_power=hessian_power, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = self.get_generator(params[0].device, seed)
+
+        beta1, beta2, eps, averaging, block_size, hessian_power = self.get_settings(params,
+            'beta1', 'beta2', 'eps', 'averaging', 'block_size', "hessian_power", cls=NumberList)
+
+        exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = var.closure
+        assert closure is not None
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad = None
+            for i in range(n_samples):
+                u = [_rademacher_like(p, generator=generator) for p in params]
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+                Hvp = tuple(Hvp)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            assert D is not None
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
+
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+        var.update = adahessian(
+            tensors=TensorList(update),
+            D=TensorList(D) if D is not None else None,
+            exp_avg_=exp_avg,
+            D_exp_avg_sq_=D_exp_avg_sq,
+            beta1=beta1,
+            beta2=beta2,
+            update_freq=update_freq,
+            eps=eps,
+            hessian_power=hessian_power,
+            step=step,
+        )
+        return var
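The preconditioner above is built from Hutchinson-style probes: for a Rademacher vector ``z``, the elementwise product ``z * (Hz)`` is an unbiased estimate of the Hessian diagonal, and ``AdaHessian`` keeps an exponential moving average of its square. A self-contained sketch of the bare estimator in plain PyTorch; the names are illustrative and not the torchzero helpers:

```python
import torch

def hutchinson_hessian_diag(loss_fn, x, n_samples=1):
    """Estimate diag(H) of loss_fn at x as the mean of z * (H @ z) over Rademacher z."""
    loss = loss_fn(x)
    (grad,) = torch.autograd.grad(loss, x, create_graph=True)
    est = torch.zeros_like(x)
    for _ in range(n_samples):
        # Rademacher probe: entries are +1 or -1 with equal probability
        z = torch.bernoulli(torch.full_like(x, 0.5)) * 2 - 1
        (hz,) = torch.autograd.grad(grad, x, grad_outputs=z, retain_graph=True)
        est += z * hz
    return est / n_samples

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss_fn = lambda p: (p ** 4).sum()           # true Hessian diagonal is 12 * p**2
print(hutchinson_hessian_diag(loss_fn, x))   # exact here because the Hessian is diagonal
```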
torchzero/modules/{optimizers → adaptive}/adam.py
@@ -10,9 +10,6 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..lr.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
 
 
 def adam_(
@@ -33,7 +30,7 @@ def adam_(
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
 ):
-    """Returns new tensors or updates params in-place."""
+    """Returns new tensors."""
     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
                                    debiased=False,step=step,pow=pow)
 
@@ -43,11 +40,12 @@ def adam_(
 
     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-    return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
+    return (exp_avg_.lazy_mul(alpha) / sqrt_exp_avg_sq.add_(eps))
 
 class Adam(Transform):
-    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size. This implementation is slightly different from
-    pytorch in that debiasing is applied after adding epsilon.
+    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
+
+    This implementation is identical to :code:`torch.optim.Adam`.
 
     Args:
         beta1 (float, optional): momentum. Defaults to 0.9.
@@ -75,7 +73,7 @@ class Adam(Transform):
         if inner is not None: self.set_child('inner', inner)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
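For reference, ``adam_`` implements the standard debiased Adam rule from Kingma & Ba, which is also the formulation used by ``torch.optim.Adam`` that the updated docstring refers to; moving ``lazy_mul(alpha)`` onto the first-moment EMA only changes where the per-parameter step size is applied, since scalar scaling commutes with the elementwise division. A minimal sketch of that rule in plain PyTorch, with illustrative names:

```python
import torch

def adam_update(grad, exp_avg, exp_avg_sq, step, beta1=0.9, beta2=0.999,
                alpha=1e-3, eps=1e-8):
    """Debiased Adam: EMA of gradients divided by the root of the EMA of squared gradients."""
    exp_avg.lerp_(grad, 1 - beta1)                                 # first moment EMA
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)   # second moment EMA
    m_hat = exp_avg / (1 - beta1 ** step)                          # bias-corrected first moment
    v_hat = exp_avg_sq / (1 - beta2 ** step)                       # bias-corrected second moment
    return alpha * m_hat / (v_hat.sqrt() + eps)

g = torch.tensor([0.1, -0.3])
m, v = torch.zeros_like(g), torch.zeros_like(g)
print(adam_update(g, m, v, step=1))   # first step is approximately alpha * sign(g)
```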