torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/adaptive/adan.py
@@ -0,0 +1,96 @@
+ import torch
+
+ from ...core import Transform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+ def adan_(
+     g: TensorList,
+     g_prev_: TensorList,
+     m_: TensorList, # exponential moving average
+     v_: TensorList, # exponential moving average of gradient differences
+     n_: TensorList, # kinda like squared momentum
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     beta3: float | NumberList,
+     eps: float | NumberList,
+     step: int,
+ ):
+     """Returns new tensors"""
+     m_.lerp_(g, 1 - beta1)
+
+     if step == 1:
+         term = g
+     else:
+         diff = g - g_prev_
+         v_.lerp_(diff, 1 - beta2)
+         term = g + beta2 * diff
+
+     n_.mul_(beta3).addcmul_(term, term, value=(1 - beta3))
+
+     m = m_ / (1.0 - beta1**step)
+     v = v_ / (1.0 - beta2**step)
+     n = n_ / (1.0 - beta3**step)
+
+     denom = n.sqrt_().add_(eps)
+     num = m + beta2 * v
+
+     update = num.div_(denom)
+     g_prev_.copy_(g)
+
+     return update
+
+
+ class Adan(Transform):
+     """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677.
+
+     Args:
+         beta1 (float, optional): momentum. Defaults to 0.98.
+         beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
+         beta3 (float, optional): third (squared) momentum. Defaults to 0.99.
+         eps (float, optional): epsilon. Defaults to 1e-8.
+         use_n_prev (bool, optional):
+             whether to use previous gradient differences momentum.
+
+     Example:
+         ```python
+         opt = tz.Modular(
+             model.parameters(),
+             tz.m.Adan(),
+             tz.m.LR(1e-3),
+         )
+         ```
+
+     Reference:
+         Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive Nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
+     """
+     def __init__(
+         self,
+         beta1: float = 0.98,
+         beta2: float = 0.92,
+         beta3: float = 0.99,
+         eps: float = 1e-8,
+     ):
+         defaults = dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = TensorList(tensors)
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         beta1, beta2, beta3, eps = unpack_dicts(settings, 'beta1', 'beta2', 'beta3', 'eps', cls=NumberList)
+         g_prev, m, v, n = unpack_states(states, tensors, 'g_prev', 'm', 'v', 'n', cls=TensorList)
+
+         update = adan_(
+             g=tensors,
+             g_prev_=g_prev,
+             m_=m,
+             v_=v,
+             n_=n,
+             beta1=beta1,
+             beta2=beta2,
+             beta3=beta3,
+             eps=eps,
+             step=step,
+         )
+
+         return update
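For readers skimming the diff, the recurrences that `adan_` implements can be written out for a single plain tensor as below. This is an illustrative sketch only; `adan_step` is a hypothetical name, and the single-tensor setting is not how the module is used (it operates on `TensorList`s).

```python
import torch

def adan_step(g, g_prev, m, v, n, beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8, step=1):
    # m: EMA of gradients; v: EMA of gradient differences; n: EMA of the squared Nesterov term
    m.lerp_(g, 1 - beta1)                                 # m <- beta1*m + (1-beta1)*g
    if step == 1:
        term = g.clone()
    else:
        diff = g - g_prev
        v.lerp_(diff, 1 - beta2)                          # v <- beta2*v + (1-beta2)*(g - g_prev)
        term = g + beta2 * diff
    n.mul_(beta3).addcmul_(term, term, value=1 - beta3)   # n <- beta3*n + (1-beta3)*term^2
    # bias-corrected estimates, then the Adam-like step
    m_hat, v_hat, n_hat = m / (1 - beta1**step), v / (1 - beta2**step), n / (1 - beta3**step)
    g_prev.copy_(g)
    return (m_hat + beta2 * v_hat) / (n_hat.sqrt() + eps)

# one toy step on a single parameter tensor
g = torch.randn(4)
state = dict(g_prev=torch.zeros(4), m=torch.zeros(4), v=torch.zeros(4), n=torch.zeros(4))
update = adan_step(g, **state, step=1)
```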
torchzero/modules/adaptive/adaptive_heavyball.py
@@ -0,0 +1,54 @@
+ import torch
+ from ...core import Transform
+ from ...utils import TensorList, unpack_dicts, unpack_states
+
+
+ def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p: TensorList, p_prev: TensorList):
+     if f - f_star <= torch.finfo(p[0].dtype).tiny * 2: return g
+
+     g_g = g.dot(g)
+     g_gp = g.dot(g_prev)
+     num = -(f - f_star) * g.dot(g_prev)
+     denom = (f_prev - f_star) * g_g + (f - f_star) * g_gp
+     m = num / denom
+
+     h = 2 * (f - f_star) / g_g
+     return (1 + m) * h * g - m * (p - p_prev)
+
+
+ class AdaptiveHeavyBall(Transform):
+     """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.
+
+     This method is related to conjugate gradient methods. It can work very well on non-stochastic convex objectives, but it won't work on stochastic ones.
+
+     note:
+         The step size is determined by the algorithm, so learning rate modules shouldn't be used.
+
+     Args:
+         f_star (float, optional):
+             (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+     """
+     def __init__(self, f_star: float = 0):
+         defaults = dict(f_star=f_star)
+         super().__init__(defaults, uses_grad=False, uses_loss=True)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert loss is not None
+         tensors = TensorList(tensors)
+         f_star = self.defaults['f_star']
+
+         f_prev = self.global_state.get('f_prev', None)
+         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params, tensors], cls=TensorList)
+
+         if f_prev is None:
+             self.global_state['f_prev'] = loss
+             h = 2 * (loss - f_star) / tensors.dot(tensors)
+             return h * tensors
+
+         update = adaptive_heavy_ball(f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)
+
+         self.global_state['f_prev'] = loss
+         p_prev.copy_(params)
+         g_prev.copy_(tensors)
+         return update
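The AdaptiveHeavyBall docstring has a note but no usage example. Following the construction pattern used by the other modules in this diff, a minimal setup would presumably look like the sketch below; it is not taken from the package docs. Note the absence of `tz.m.LR`, since the note above says the step size is determined by the algorithm, and `f_star=0` assumes the loss is bounded below by zero. `model` is a placeholder.

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.AdaptiveHeavyBall(f_star=0),  # no tz.m.LR: the method chooses its own step size
)
```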
torchzero/modules/adaptive/aegd.py
@@ -0,0 +1,54 @@
+ import math
+
+ import torch
+
+ from ...core import Transform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+ # I've verified it is identical to the official implementation
+ # https://github.com/txping/AEGD/blob/master/aegd.py
+ def aegd_(f: torch.Tensor | float, g: TensorList, r_: TensorList, c: float | NumberList = 1, eta: float | NumberList = 0.1) -> TensorList:
+     v = g / (2 * (f + c)**0.5)
+     r_ /= 1 + (v ** 2).mul_(2*eta) # update energy
+     return 2*eta * r_*v # pyright:ignore[reportReturnType]
+
+ class AEGD(Transform):
+     """AEGD (adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.
+
+     Note:
+         AEGD has a learning rate hyperparameter that can't really be removed from the update rule.
+         To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.
+
+     Args:
+         lr (float, optional): step size. Defaults to 0.1.
+         c (float, optional): constant added to the loss under the square root. Defaults to 1.
+     """
+     def __init__(
+         self,
+         lr: float = 0.1,
+         c: float = 1,
+     ):
+         defaults = dict(c=c, lr=lr)
+         super().__init__(defaults, uses_loss=True)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert loss is not None
+         tensors = TensorList(tensors)
+
+         c, lr = unpack_dicts(settings, 'c', 'lr', cls=NumberList)
+         r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss + c[0])**0.5), cls=TensorList)
+
+         update = aegd_(
+             f=loss,
+             g=tensors,
+             r_=r,
+             c=c,
+             eta=lr,
+         )
+
+         return update
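Again following the pattern of the other docstring examples in this diff, AEGD would presumably be used without a separate learning-rate module, since the note above warns against stacking `tz.m.LR` on top of its built-in `lr`. This is a hedged sketch, not taken from the package docs; `model` is a placeholder.

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.AEGD(lr=0.1, c=1),  # lr is part of the AEGD update rule; don't add tz.m.LR on top
)
```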
torchzero/modules/adaptive/esgd.py
@@ -0,0 +1,171 @@
+ import math
+ from collections.abc import Callable
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, Transform, apply_transform
+ from ...utils import NumberList, TensorList, as_tensorlist
+
+
+ def esgd_(
+     tensors_: TensorList,
+     D: TensorList | None,
+     D_sq_acc_: TensorList,
+     damping: float | NumberList,
+     update_freq: int,
+     step: int,
+     i: int,
+ ):
+     # update preconditioner
+     if step % update_freq == 0:
+         assert D is not None
+         D_sq_acc_.addcmul_(D, D)
+         i += 1
+     else:
+         assert D is None
+
+     denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
+     return tensors_.div_(denom), i
+
+
+ class ESGD(Module):
+     """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)
+
+     This is similar to Adagrad, but it accumulates squared randomized Hessian diagonal estimates instead of squared gradients.
+
+     .. note::
+         In most cases ESGD should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply ESGD preconditioning to another module's output.
+
+     .. note::
+         If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+     .. note::
+         This module requires a closure passed to the optimizer step,
+         as it needs to re-evaluate the loss and gradients for calculating HVPs.
+         The closure must accept a ``backward`` argument (refer to documentation).
+
+     Args:
+         damping (float, optional): added to denominator for stability. Defaults to 1e-4.
+         update_freq (int, optional):
+             frequency of updating the Hessian diagonal estimate via a Hessian-vector product.
+             This value can be increased to reduce computational cost. Defaults to 20.
+         hvp_method (str, optional):
+             Determines how Hessian-vector products are evaluated.
+
+             - ``"autograd"``: use PyTorch's autograd to calculate exact HVPs.
+               This requires creating a graph for the gradient.
+             - ``"forward"``: use a forward finite difference formula to
+               approximate the HVP. This requires one extra gradient evaluation.
+             - ``"central"``: use a central finite difference formula for a
+               more accurate HVP approximation. This requires two extra
+               gradient evaluations.
+
+             Defaults to "autograd".
+         fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+         n_samples (int, optional):
+             number of Hessian-vector products with random vectors to evaluate each time the
+             preconditioner is updated. Larger values may give a better Hessian diagonal estimate. Defaults to 1.
+         seed (int | None, optional): seed for random vectors. Defaults to None.
+         inner (Chainable | None, optional):
+             Inner module. If this is specified, operations are performed in the following order:
+             1. compute the Hessian diagonal estimate;
+             2. pass inputs to :code:`inner`;
+             3. momentum and preconditioning are applied to the outputs of :code:`inner`.
+
+     Examples:
+         Using ESGD:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.ESGD(),
+                 tz.m.LR(0.1)
+             )
+
+         The ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
+         ESGD preconditioning to Nesterov momentum (:code:`tz.m.NAG`):
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.ESGD(inner=tz.m.NAG(0.9)),
+                 tz.m.LR(0.1)
+             )
+
+     """
+     def __init__(
+         self,
+         damping: float = 1e-4,
+         update_freq: int = 20,
+         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+         fd_h: float = 1e-3,
+         n_samples = 1,
+         seed: int | None = None,
+         inner: Chainable | None = None
+     ):
+         defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         fd_h = settings['fd_h']
+         update_freq = settings['update_freq']
+         n_samples = settings['n_samples']
+
+         seed = settings['seed']
+         generator = None
+         if seed is not None:
+             if 'generator' not in self.global_state:
+                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             generator = self.global_state['generator']
+
+         damping = self.get_settings(params, 'damping', cls=NumberList)
+         D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
+         i = self.global_state.get('i', 0)
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         closure = var.closure
+         assert closure is not None
+
+         D = None
+         if step % update_freq == 0:
+
+             rgrad = None
+             for j in range(n_samples):
+                 u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]
+
+                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                       h=fd_h, normalize=True, retain_grad=j < n_samples-1)
+
+                 if D is None: D = Hvp
+                 else: torch._foreach_add_(D, Hvp)
+
+             assert D is not None
+             if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+             D = TensorList(D)
+
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+         var.update, self.global_state['i'] = esgd_(
+             tensors_=TensorList(update),
+             D=TensorList(D) if D is not None else None,
+             D_sq_acc_=D_sq_acc,
+             damping=damping,
+             update_freq=update_freq,
+             step=step,
+             i=i,
+         )
+         return var
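The ESGD docstring requires a closure that accepts a ``backward`` argument so the module can re-evaluate the loss and gradients for HVPs. The exact protocol is defined in the torchzero documentation; the sketch below is only a plausible shape, assuming ``tz.Modular`` follows the standard ``torch.optim`` closure interface and that ``backward=True`` means the closure should populate gradients. ``model``, ``loss_fn``, ``inputs`` and ``targets`` are placeholder names.

```python
opt = tz.Modular(model.parameters(), tz.m.ESGD(), tz.m.LR(0.1))

def closure(backward=True):
    loss = loss_fn(model(inputs), targets)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # ESGD re-invokes the closure internally when it needs HVPs
```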
torchzero/modules/{optimizers → adaptive}/lion.py
@@ -28,7 +28,7 @@ class Lion(Transform):
          super().__init__(defaults, uses_grad=False)
  
      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
          exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
          return lion_(TensorList(tensors), exp_avg, beta1, beta2)
torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py}
@@ -1,55 +1,57 @@
- from abc import ABC, abstractmethod
- import math
  from collections import deque
  from typing import Literal, Any
- import itertools
+ import warnings
  
  import torch
  from ...core import Chainable, TensorwiseTransform
- from ...utils.linalg.matrix_funcs import matrix_power_eigh
- from ...utils.linalg.svd import randomized_svd
- from ...utils.linalg.qr import qr_householder
  
- def spectral_update(history, damping, rdamping, true_damping: bool):
-     M_hist = torch.stack(tuple(history), dim=1)
-     device = M_hist.device
-     M_hist = M_hist.cuda()
+ def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping):
+     if isinstance(history, torch.Tensor):
+         M = history
+     else:
+         M = torch.stack(tuple(history), dim=1) # / len(history)
+
+     MTM = M.T @ M
+     if damping != 0:
+         MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))
  
      try:
-         U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver='gesvda') # pylint:disable=not-callable
-         U = U.to(device); S = S.to(device)
+         L, Q = torch.linalg.eigh(MTM) # pylint:disable=not-callable
+
+         tol = torch.finfo(M.dtype).eps * L.amax() # remove small eigenvalues
+         indices = L > tol
+         L = L[indices]
+         Q = Q[:, indices]
  
-         if damping != 0 or rdamping != 0:
-             if rdamping != 0: rdamping *= torch.linalg.vector_norm(S) # pylint:disable=not-callable
-             Iu = damping + rdamping
-             if true_damping:
-                 S.pow_(2)
-                 Iu **= 2
-             S.add_(Iu)
-             if true_damping: S.sqrt_()
+         U = (M @ Q) * L.rsqrt()
  
-         return U, 1/S
+         if rdamping != 0:
+             rdamping *= torch.linalg.vector_norm(L) # pylint:disable=not-callable
+             L.add_(rdamping)
+
+         return U, L
  
      except torch.linalg.LinAlgError:
          return None, None
  
- def spectral_apply(g: torch.Tensor, U: torch.Tensor, S_inv: torch.Tensor):
-     Utg = (U.T @ g)*S_inv
-     return U @ Utg
-
+ def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
+     Z = U.T @ g
+     return (U * L.rsqrt()) @ Z
  
  def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
      if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
      else:
-         if state_[key].shape != value.shape: state_[key] = value
+         if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
          else: state_[key].lerp_(value, 1-beta)
  
- class SpectralPreconditioner(TensorwiseTransform):
+ class LMAdagrad(TensorwiseTransform):
      """
-     The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate U (Uᵀg)/S.
-     This is equivalent to full matrix Adagrad with accumulator initialized to zeros,
-     except only recent :code:`history_size` gradients are used.
-     However this doesn't require N^2 memory and is computationally less expensive than Shampoo.
+     Limited-memory full-matrix Adagrad.
+
+     The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate the update as U S⁻¹ Uᵀg.
+     It uses an eigendecomposition of MᵀM to obtain U and S² instead, because that is faster when V is not needed.
+
+     This is equivalent to full-matrix Adagrad on recent gradients.
  
      Args:
          history_size (int, optional): number of past gradients to store. Defaults to 10.
@@ -61,54 +63,81 @@ class SpectralPreconditioner(TensorwiseTransform):
          true_damping (bool, optional):
              If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
          U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
-         S_beta (float | None, optional): momentum for 1/S (too unstable, don't use). Defaults to None.
+         L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
          interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
-         concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to False.
-         normalize (bool, optional): whether to normalize gradients, this doesn't work well so don't use it. Defaults to False.
-         centralize (bool, optional): whether to centralize gradients, this doesn't work well so don't use it. Defaults to False.
+         concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten across parameters. Defaults to True.
          inner (Chainable | None, optional): preconditioner will be applied to the output of this module. Defaults to None.
+
+     ## Examples:
+
+     Limited-memory Adagrad
+
+     ```python
+     optimizer = tz.Modular(
+         model.parameters(),
+         tz.m.LMAdagrad(),
+         tz.m.LR(0.1)
+     )
+     ```
+
+     Adam with LM-Adagrad preconditioner (the second beta for debiasing is set to 0.999 arbitrarily)
+
+     ```python
+     optimizer = tz.Modular(
+         model.parameters(),
+         tz.m.LMAdagrad(inner=tz.m.EMA()),
+         tz.m.Debias(0.9, 0.999),
+         tz.m.LR(0.01)
+     )
+     ```
+
+     Stable Adam with LM-Adagrad preconditioner (this is what I would recommend)
+
+     ```python
+     optimizer = tz.Modular(
+         model.parameters(),
+         tz.m.LMAdagrad(inner=tz.m.EMA()),
+         tz.m.Debias(0.9, 0.999),
+         tz.m.ClipNormByEMA(max_ema_growth=1.2),
+         tz.m.LR(0.01)
+     )
+     ```
+
+     Reference:
+         Agarwal, N. et al. Efficient Full-Matrix Adaptive Regularization. International Conference on Machine Learning, PMLR, 2019, pp. 102-110.
      """
  
      def __init__(
          self,
-         history_size: int = 10,
+         history_size: int = 100,
          update_freq: int = 1,
          damping: float = 1e-4,
          rdamping: float = 0,
          order: int = 1,
         true_damping: bool = True,
          U_beta: float | None = None,
-         S_beta: float | None = None,
+         L_beta: float | None = None,
          interval: int = 1,
-         concat_params: bool = False,
-         normalize: bool = False,
-         centralize: bool = False,
+         concat_params: bool = True,
          inner: Chainable | None = None,
      ):
          # history is still updated each step so Precondition's update_freq has a different meaning
-         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, S_beta=S_beta, normalize=normalize, centralize=centralize)
+         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
          super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)
  
      @torch.no_grad
-     def update_tensor(self, tensor, param, grad, loss, state, settings):
-         order = settings['order']
-         history_size = settings['history_size']
-         update_freq = settings['update_freq']
-         damping = settings['damping']
-         rdamping = settings['rdamping']
-         true_damping = settings['true_damping']
-         U_beta = settings['U_beta']
-         S_beta = settings['S_beta']
-         normalize = settings['normalize']
-         centralize = settings['centralize']
+     def update_tensor(self, tensor, param, grad, loss, state, setting):
+         order = setting['order']
+         history_size = setting['history_size']
+         update_freq = setting['update_freq']
+         damping = setting['damping']
+         rdamping = setting['rdamping']
+         U_beta = setting['U_beta']
+         L_beta = setting['L_beta']
  
          if 'history' not in state: state['history'] = deque(maxlen=history_size)
         history = state['history']
  
         if order == 1:
             t = tensor.clone().view(-1)
-             if centralize: t -= t.mean()
-             if normalize: t /= torch.linalg.vector_norm(t).clip(min=1e-8) # pylint:disable=not-callable
             history.append(t)
         else:
 
@@ -116,48 +145,42 @@ class SpectralPreconditioner(TensorwiseTransform):
             # scaled by parameter differences
             cur_p = param.clone()
             cur_g = tensor.clone()
+             eps = torch.finfo(cur_p.dtype).tiny * 2
             for i in range(1, order):
                 if f'prev_g_{i}' not in state:
                     state[f'prev_p_{i}'] = cur_p
                     state[f'prev_g_{i}'] = cur_g
                     break
 
-                 s_k = cur_p - state[f'prev_p_{i}']
-                 y_k = cur_g - state[f'prev_g_{i}']
+                 s = cur_p - state[f'prev_p_{i}']
+                 y = cur_g - state[f'prev_g_{i}']
                 state[f'prev_p_{i}'] = cur_p
                 state[f'prev_g_{i}'] = cur_g
-                 cur_p = s_k
-                 cur_g = y_k
+                 cur_p = s
+                 cur_g = y
 
                 if i == order - 1:
-                     if centralize: cur_g = cur_g - cur_g.mean()
-                     if normalize: cur_g = cur_g / torch.linalg.vector_norm(cur_g).clip(min=1e-8) # pylint:disable=not-callable
-                     else: cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
+                     cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=eps) # pylint:disable=not-callable
                     history.append(cur_g.view(-1))
 
         step = state.get('step', 0)
         if step % update_freq == 0 and len(history) != 0:
-             U, S_inv = spectral_update(history, damping=damping, rdamping=rdamping, true_damping=true_damping)
+             U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
             maybe_lerp_(state, U_beta, 'U', U)
-             maybe_lerp_(state, S_beta, 'S_inv', S_inv)
+             maybe_lerp_(state, L_beta, 'L', L)
 
         if len(history) != 0:
             state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)
 
     @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
-         history_size = settings['history_size']
-
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
         U = state.get('U', None)
         if U is None:
             # make a conservative step to avoid issues due to different GD scaling
             return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
 
-         S_inv = state['S_inv']
-         update = spectral_apply(tensor.view(-1), U, S_inv).view_as(tensor)
+         L = state['L']
+         update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)
 
-         n = len(state['history'])
-         mh = min(history_size, 10)
-         if n <= mh: update.mul_(n/mh)
         return update
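The rewritten docstring states that LMAdagrad computes the same U S⁻¹ Uᵀg update as the SVD formulation, but via an eigendecomposition of MᵀM, since V is never needed. The standalone sketch below illustrates that equivalence with plain torch calls; it is for illustration only, and none of these names are part of torchzero.

```python
import torch

# M stacks k recent gradient vectors as columns (n x k), as in lm_adagrad_update above
n, k = 64, 8
M = torch.randn(n, k, dtype=torch.float64)
g = torch.randn(n, dtype=torch.float64)

# SVD route: U S^-1 U^T g
U_svd, S, _ = torch.linalg.svd(M, full_matrices=False)
ref = U_svd @ ((U_svd.T @ g) / S)

# eigh route: decompose the small k x k matrix M^T M, so V is never formed
L, Q = torch.linalg.eigh(M.T @ M)   # eigenvalues L equal S**2 (up to ordering)
U = (M @ Q) * L.rsqrt()             # recovers the left singular vectors
out = (U * L.rsqrt()) @ (U.T @ g)   # U S^-1 U^T g

print(torch.allclose(ref, out))     # True up to floating-point error
```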