torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/esgd.py (added, +171)
@@ -0,0 +1,171 @@
+ import math
+ from collections.abc import Callable
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, Transform, apply_transform
+ from ...utils import NumberList, TensorList, as_tensorlist
+
+
+ def esgd_(
+     tensors_: TensorList,
+     D: TensorList | None,
+     D_sq_acc_: TensorList,
+     damping: float | NumberList,
+     update_freq: int,
+     step: int,
+     i: int,
+ ):
+     # update preconditioner
+     if step % update_freq == 0:
+         assert D is not None
+         D_sq_acc_.addcmul_(D, D)
+         i += 1
+     else:
+         assert D is None
+
+     denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
+     return tensors_.div_(denom), i
+
+
+ class ESGD(Module):
+     """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)
+
+     This is similar to Adagrad, but it accumulates squared randomized Hessian diagonal estimates instead of squared gradients.
+
+     .. note::
+         In most cases ESGD should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply ESGD preconditioning to another module's output.
+
+     .. note::
+         If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+     .. note::
+         This module requires a closure passed to the optimizer step,
+         as it needs to re-evaluate the loss and gradients for calculating HVPs.
+         The closure must accept a ``backward`` argument (refer to documentation).
+
+     Args:
+         damping (float, optional): added to the denominator for stability. Defaults to 1e-4.
+         update_freq (int, optional):
+             frequency of updating the Hessian diagonal estimate via a Hessian-vector product.
+             This value can be increased to reduce computational cost. Defaults to 20.
+         hvp_method (str, optional):
+             Determines how Hessian-vector products are evaluated.
+
+             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+               This requires creating a graph for the gradient.
+             - ``"forward"``: Use a forward finite difference formula to
+               approximate the HVP. This requires one extra gradient evaluation.
+             - ``"central"``: Use a central finite difference formula for a
+               more accurate HVP approximation. This requires two extra
+               gradient evaluations.
+             Defaults to "autograd".
+         fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+         n_samples (int, optional):
+             number of Hessian-vector products with random vectors to evaluate each time the
+             preconditioner is updated. Larger values may lead to a better Hessian diagonal estimate. Defaults to 1.
+         seed (int | None, optional): seed for random vectors. Defaults to None.
+         inner (Chainable | None, optional):
+             Inner module. If this is specified, operations are performed in the following order:
+             1. compute the Hessian diagonal estimate;
+             2. pass inputs to :code:`inner`;
+             3. apply preconditioning to the outputs of :code:`inner`.
+
+     Examples:
+         Using ESGD:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.ESGD(),
+                 tz.m.LR(0.1)
+             )
+
+         The ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
+         ESGD preconditioning to Nesterov momentum (:code:`tz.m.NAG`):
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.ESGD(inner=tz.m.NAG(0.9)),
+                 tz.m.LR(0.1)
+             )
+
+     """
+     def __init__(
+         self,
+         damping: float = 1e-4,
+         update_freq: int = 20,
+         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+         fd_h: float = 1e-3,
+         n_samples = 1,
+         seed: int | None = None,
+         inner: Chainable | None = None
+     ):
+         defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         fd_h = settings['fd_h']
+         update_freq = settings['update_freq']
+         n_samples = settings['n_samples']
+
+         seed = settings['seed']
+         generator = None
+         if seed is not None:
+             if 'generator' not in self.global_state:
+                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             generator = self.global_state['generator']
+
+         damping = self.get_settings(params, 'damping', cls=NumberList)
+         D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
+         i = self.global_state.get('i', 0)
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         closure = var.closure
+         assert closure is not None
+
+         D = None
+         if step % update_freq == 0:
+
+             rgrad = None
+             for j in range(n_samples):
+                 u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]
+
+                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                       h=fd_h, normalize=True, retain_grad=j < n_samples-1)
+
+                 if D is None: D = Hvp
+                 else: torch._foreach_add_(D, Hvp)
+
+             assert D is not None
+             if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+             D = TensorList(D)
+
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+         var.update, self.global_state['i'] = esgd_(
+             tensors_=TensorList(update),
+             D=TensorList(D) if D is not None else None,
+             D_sq_acc_=D_sq_acc,
+             damping=damping,
+             update_freq=update_freq,
+             step=step,
+             i=i,
+         )
+         return var
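
Note: the sketch below illustrates the equilibration idea behind `esgd_` outside of torchzero, assuming a plain PyTorch setup (the quadratic `loss_fn`, the `hvp` helper and the hyperparameters here are illustrative, not part of the torchzero API). For a Gaussian probe u, E[(Hu)_i^2] equals the squared norm of the i-th Hessian row, so accumulating squared HVPs and dividing by the square root of their mean equilibrates badly scaled coordinates, with `damping` keeping the denominator away from zero.

import torch

def hvp(loss_fn, x, v):
    # exact Hessian-vector product via double backward
    loss = loss_fn(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    (Hv,) = torch.autograd.grad(g, x, grad_outputs=v)
    return Hv

torch.manual_seed(0)

# toy quadratic with a badly scaled (diagonal) Hessian
H = torch.diag(torch.tensor([100.0, 1.0, 0.01]))
loss_fn = lambda p: 0.5 * p @ H @ p
x = torch.ones(3, requires_grad=True)

D_sq_acc = torch.zeros(3)          # accumulator of squared diagonal estimates
damping, lr = 1e-4, 0.5

for t in range(100):
    u = torch.randn(3)                         # random probe vector
    Hu = hvp(loss_fn, x, u).detach()           # one HVP per step (update_freq=1)
    D_sq_acc += Hu * Hu                        # accumulate squared estimate
    denom = (D_sq_acc / (t + 1)).sqrt() + damping
    with torch.no_grad():
        x -= lr * (H @ x) / denom              # equilibrated gradient step

print(x.detach())   # coordinates with curvature 100, 1 and 0.01 all decay at comparable rates

In the module above the same estimate is refreshed only every `update_freq` steps and averaged over `n_samples` probes to amortize the extra HVP cost.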
torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} (renamed and modified, +91 -71)
@@ -1,55 +1,53 @@
- from abc import ABC, abstractmethod
- import math
  from collections import deque
  from typing import Literal, Any
- import itertools
+ import warnings

  import torch
  from ...core import Chainable, TensorwiseTransform
- from ...utils.linalg.matrix_funcs import matrix_power_eigh
- from ...utils.linalg.svd import randomized_svd
- from ...utils.linalg.qr import qr_householder

- def spectral_update(history, damping, rdamping, true_damping: bool):
-     M_hist = torch.stack(tuple(history), dim=1)
-     device = M_hist.device
-     M_hist = M_hist.cuda()
+ def lm_adagrad_update(history: deque[torch.Tensor], damping, rdamping):
+     M = torch.stack(tuple(history), dim=1)  # / len(history)
+     MTM = M.T @ M
+     if damping != 0:
+         MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))

      try:
-         U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver='gesvda') # pylint:disable=not-callable
-         U = U.to(device); S = S.to(device)
+         L, Q = torch.linalg.eigh(MTM) # pylint:disable=not-callable

-         if damping != 0 or rdamping != 0:
-             if rdamping != 0: rdamping *= torch.linalg.vector_norm(S) # pylint:disable=not-callable
-             Iu = damping + rdamping
-             if true_damping:
-                 S.pow_(2)
-                 Iu **= 2
-             S.add_(Iu)
-             if true_damping: S.sqrt_()
+         tol = torch.finfo(M.dtype).eps * L.amax() # remove small eigenvalues
+         indices = L > tol
+         L = L[indices]
+         Q = Q[:, indices]

-         return U, 1/S
+         U = (M @ Q) * L.rsqrt()
+
+         if rdamping != 0:
+             rdamping *= torch.linalg.vector_norm(L) # pylint:disable=not-callable
+             L.add_(rdamping)
+
+         return U, L

      except torch.linalg.LinAlgError:
          return None, None

- def spectral_apply(g: torch.Tensor, U: torch.Tensor, S_inv: torch.Tensor):
-     Utg = (U.T @ g)*S_inv
-     return U @ Utg
-
+ def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
+     Z = U.T @ g
+     return (U * L.rsqrt()) @ Z

  def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
      if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
      else:
-         if state_[key].shape != value.shape: state_[key] = value
+         if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
          else: state_[key].lerp_(value, 1-beta)

- class SpectralPreconditioner(TensorwiseTransform):
+ class LMAdagrad(TensorwiseTransform):
      """
-     The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate U (Uᵀg)/S.
-     This is equivalent to full matrix Adagrad with accumulator initialized to zeros,
-     except only recent :code:`history_size` gradients are used.
-     However this doesn't require N^2 memory and is computationally less expensive than Shampoo.
+     Limited-memory full-matrix Adagrad.
+
+     The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate the update as U S⁻¹ Uᵀg.
+     It uses an eigendecomposition of MᵀM to obtain U and S², which is faster when V is not needed.
+
+     This is equivalent to full-matrix Adagrad on recent gradients.

      Args:
          history_size (int, optional): number of past gradients to store. Defaults to 10.
@@ -60,55 +58,84 @@ class SpectralPreconditioner(TensorwiseTransform):
              order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
          true_damping (bool, optional):
              If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
+         eigh (bool, optional): uses a more efficient way to calculate U and S. Defaults to True.
          U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
-         S_beta (float | None, optional): momentum for 1/S (too unstable, don't use). Defaults to None.
+         L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
          interval (int, optional): interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
-         concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to False.
-         normalize (bool, optional): whether to normalize gradients, this doesn't work well so don't use it. Defaults to False.
-         centralize (bool, optional): whether to centralize gradients, this doesn't work well so don't use it. Defaults to False.
+         concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
          inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
+
+     Examples:
+         Limited-memory Adagrad
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.LMAdagrad(),
+                 tz.m.LR(0.1)
+             )
+
+         Adam with an L-Adagrad preconditioner (the second beta for debiasing is set to 0.999 arbitrarily)
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.LMAdagrad(inner=tz.m.EMA()),
+                 tz.m.Debias(0.9, 0.999),
+                 tz.m.LR(0.01)
+             )
+
+         Stable Adam with an L-Adagrad preconditioner (this is what I would recommend)
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.LMAdagrad(inner=tz.m.EMA()),
+                 tz.m.Debias(0.9, 0.999),
+                 tz.m.ClipNormByEMA(max_ema_growth=1.2),
+                 tz.m.LR(0.01)
+             )
+
+     Reference:
+         Agarwal N. et al. Efficient full-matrix adaptive regularization. International Conference on Machine Learning, PMLR, 2019, pp. 102-110.
      """

      def __init__(
          self,
-         history_size: int = 10,
+         history_size: int = 100,
          update_freq: int = 1,
          damping: float = 1e-4,
          rdamping: float = 0,
          order: int = 1,
          true_damping: bool = True,
          U_beta: float | None = None,
-         S_beta: float | None = None,
+         L_beta: float | None = None,
          interval: int = 1,
-         concat_params: bool = False,
-         normalize: bool=False,
-         centralize:bool = False,
+         concat_params: bool = True,
          inner: Chainable | None = None,
      ):
          # history is still updated each step so Precondition's update_freq has different meaning
-         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, S_beta=S_beta, normalize=normalize, centralize=centralize)
+         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
          super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)

      @torch.no_grad
-     def update_tensor(self, tensor, param, grad, loss, state, settings):
-         order = settings['order']
-         history_size = settings['history_size']
-         update_freq = settings['update_freq']
-         damping = settings['damping']
-         rdamping = settings['rdamping']
-         true_damping = settings['true_damping']
-         U_beta = settings['U_beta']
-         S_beta = settings['S_beta']
-         normalize = settings['normalize']
-         centralize = settings['centralize']
+     def update_tensor(self, tensor, param, grad, loss, state, setting):
+         order = setting['order']
+         history_size = setting['history_size']
+         update_freq = setting['update_freq']
+         damping = setting['damping']
+         rdamping = setting['rdamping']
+         U_beta = setting['U_beta']
+         L_beta = setting['L_beta']

          if 'history' not in state: state['history'] = deque(maxlen=history_size)
          history = state['history']

          if order == 1:
              t = tensor.clone().view(-1)
-             if centralize: t -= t.mean()
-             if normalize: t /= torch.linalg.vector_norm(t).clip(min=1e-8) # pylint:disable=not-callable
              history.append(t)
          else:

@@ -122,42 +149,35 @@ class SpectralPreconditioner(TensorwiseTransform):
                      state[f'prev_g_{i}'] = cur_g
                      break

-                 s_k = cur_p - state[f'prev_p_{i}']
-                 y_k = cur_g - state[f'prev_g_{i}']
+                 s = cur_p - state[f'prev_p_{i}']
+                 y = cur_g - state[f'prev_g_{i}']
                  state[f'prev_p_{i}'] = cur_p
                  state[f'prev_g_{i}'] = cur_g
-                 cur_p = s_k
-                 cur_g = y_k
+                 cur_p = s
+                 cur_g = y

                  if i == order - 1:
-                     if centralize: cur_g = cur_g - cur_g.mean()
-                     if normalize: cur_g = cur_g / torch.linalg.vector_norm(cur_g).clip(min=1e-8) # pylint:disable=not-callable
-                     else: cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
+                     cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
                      history.append(cur_g.view(-1))

          step = state.get('step', 0)
          if step % update_freq == 0 and len(history) != 0:
-             U, S_inv = spectral_update(history, damping=damping, rdamping=rdamping, true_damping=true_damping)
+             U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
              maybe_lerp_(state, U_beta, 'U', U)
-             maybe_lerp_(state, S_beta, 'S_inv', S_inv)
+             maybe_lerp_(state, L_beta, 'L', L)

          if len(history) != 0:
              state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

      @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
-         history_size = settings['history_size']
-
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
          U = state.get('U', None)
          if U is None:
              # make a conservative step to avoid issues due to different GD scaling
              return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

-         S_inv = state['S_inv']
-         update = spectral_apply(tensor.view(-1), U, S_inv).view_as(tensor)
+         L = state['L']
+         update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)

-         n = len(state['history'])
-         mh = min(history_size, 10)
-         if n <= mh: update.mul_(n/mh)
          return update

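Note: the LMAdagrad docstring above states that eigendecomposing the small k×k matrix MᵀM reproduces the SVD-based whitening U S⁻¹ Uᵀ g. The snippet below is a self-contained check of that identity in plain PyTorch (shapes and the random data are arbitrary; this is not the torchzero API).

import torch

torch.manual_seed(0)
n, k = 50, 10                                   # parameter dimension, history size
M = torch.randn(n, k, dtype=torch.float64)      # columns = recent gradients
g = torch.randn(n, dtype=torch.float64)

# direct route: thin SVD of the n x k history matrix
U, S, _ = torch.linalg.svd(M, full_matrices=False)
direct = U @ ((U.T @ g) / S)                    # U S^-1 U^T g

# LMAdagrad route: eigendecompose the k x k matrix M^T M (L = S^2, Q = right singular vectors)
L, Q = torch.linalg.eigh(M.T @ M)
keep = L > torch.finfo(M.dtype).eps * L.amax()  # drop numerically zero eigenvalues
L, Q = L[keep], Q[:, keep]
U2 = (M @ Q) * L.rsqrt()                        # recover the left singular vectors
via_eigh = (U2 * L.rsqrt()) @ (U2.T @ g)        # same computation as lm_adagrad_apply

print(torch.allclose(direct, via_eigh))         # True: the two routes agree

Computing U this way never decomposes anything larger than k×k, which is what makes the larger default `history_size=100` practical.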
torchzero/modules/optimizers/lion.py (+1 -1)
@@ -28,7 +28,7 @@ class Lion(Transform):
          super().__init__(defaults, uses_grad=False)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
          exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
          return lion_(TensorList(tensors),exp_avg,beta1,beta2)
torchzero/modules/optimizers/mars.py (added, +91)
@@ -0,0 +1,91 @@
+ from operator import itemgetter
+ from functools import partial
+
+ import torch
+
+ from ...core import Module, Target, Transform, apply_transform, Chainable
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ..functional import (
+     debias, debiased_step_size,
+     ema_,
+     sqrt_ema_sq_,
+ )
+ from ..step_size.lr import lazy_lr
+ from ..momentum.experimental import sqrt_nag_ema_sq_
+ from ..momentum.momentum import nag_
+
+
+ def mars_correction_(
+     tensors_: TensorList,
+     prev_: TensorList,
+     beta: float | NumberList,
+     scaling: float | NumberList,
+     max_norm: float | NumberList | None,
+ ):
+     dg = (tensors_ - prev_).mul_(scaling * beta / (1-beta))
+     prev_.copy_(tensors_)
+
+     c = tensors_.add_(dg)
+     if max_norm is not None:
+         c.clip_norm_(max=max_norm, tensorwise=False)
+
+     return c
+
+ class MARSCorrection(Transform):
+     """MARS variance reduction correction.
+
+     Place any other momentum-based optimizer after this module,
+     and make sure the :code:`beta` parameter matches the momentum beta used in that optimizer.
+
+     Args:
+         beta (float, optional): use the same beta as in the momentum module. Defaults to 0.9.
+         scaling (float, optional): controls the scale of the gradient correction in variance reduction. Defaults to 0.025.
+         max_norm (float, optional): clips the norm of corrected gradients, None to disable. Defaults to 1.
+
+     Examples:
+         Mars-AdamW
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.MARSCorrection(beta=0.95),
+                 tz.m.Adam(beta1=0.95, beta2=0.99),
+                 tz.m.WeightDecay(1e-3),
+                 tz.m.LR(0.1)
+             )
+
+         Mars-Lion
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.MARSCorrection(beta=0.9),
+                 tz.m.Lion(beta1=0.9),
+                 tz.m.LR(0.1)
+             )
+
+     """
+     def __init__(
+         self,
+         beta: float = 0.9,
+         scaling: float = 0.025,
+         max_norm: float | None = 1,
+     ):
+         defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+         beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
+         max_norm = settings[0]['max_norm']
+
+         return mars_correction_(
+             tensors_=TensorList(tensors),
+             prev_=prev,
+             beta=beta,
+             scaling=scaling,
+             max_norm=max_norm,
+         )
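
Note: a minimal sketch of the correction computed by `mars_correction_` above, written against plain tensors rather than torchzero's `TensorList` machinery (the helper name and the toy loop are illustrative). The current gradient is nudged by a scaled difference from the previous gradient, c_t = g_t + (scaling · beta / (1 - beta)) · (g_t - g_{t-1}), and the result is norm-clipped before being handed to whatever momentum module follows.

import torch

def mars_correct(g, g_prev, beta=0.9, scaling=0.025, max_norm=1.0):
    # c_t = g_t + (scaling * beta / (1 - beta)) * (g_t - g_{t-1}), then clipped to max_norm
    c = g + (scaling * beta / (1.0 - beta)) * (g - g_prev)
    if max_norm is not None:
        norm = c.norm()
        if norm > max_norm:
            c = c * (max_norm / norm)
    return c

torch.manual_seed(0)
g_prev = torch.zeros(4)
for step in range(3):
    g = torch.randn(4)            # stand-in for the current gradient
    c = mars_correct(g, g_prev)   # corrected gradient fed to the next (momentum) module
    g_prev = g                    # MARSCorrection keeps this in per-tensor state
    print(step, c)

Matching `beta` here with the momentum module placed after the correction is what the docstring insists on, since the correction term is derived assuming that same decay rate.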