torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/ladagrad.py
@@ -0,0 +1,183 @@
+ from collections import deque
+ from typing import Literal, Any
+ import warnings
+
+ import torch
+ from ...core import Chainable, TensorwiseTransform
+
+ def lm_adagrad_update(history: deque[torch.Tensor], damping, rdamping):
+     M = torch.stack(tuple(history), dim=1)  # / len(history)
+     MTM = M.T @ M
+     if damping != 0:
+         MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))
+
+     try:
+         L, Q = torch.linalg.eigh(MTM) # pylint:disable=not-callable
+
+         tol = torch.finfo(M.dtype).eps * L.amax() # remove small eigenvalues
+         indices = L > tol
+         L = L[indices]
+         Q = Q[:, indices]
+
+         U = (M @ Q) * L.rsqrt()
+
+         if rdamping != 0:
+             rdamping *= torch.linalg.vector_norm(L) # pylint:disable=not-callable
+             L.add_(rdamping)
+
+         return U, L
+
+     except torch.linalg.LinAlgError:
+         return None, None
+
+ def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
+     Z = U.T @ g
+     return (U * L.rsqrt()) @ Z
+
+ def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
+     if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
+     else:
+         if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
+         else: state_[key].lerp_(value, 1-beta)
+
+ class LMAdagrad(TensorwiseTransform):
+     """
+     Limited-memory full matrix Adagrad.
+
+     The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate the update as U S^-1 Uᵀg.
+     In practice it uses an eigendecomposition of MᵀM to get U and S^2, because that is faster when you don't need V.
+
+     This is equivalent to full-matrix Adagrad on recent gradients.
+
+     Args:
+         history_size (int, optional): number of past gradients to store. Defaults to 100.
+         update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
+         damping (float, optional): damping value. Defaults to 1e-4.
+         rdamping (float, optional): value of damping relative to the norm of the singular values. Defaults to 0.
+         order (int, optional):
+             order=2 means gradient differences are used in place of gradients. Higher orders use higher-order differences. Defaults to 1.
+         true_damping (bool, optional):
+             If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
+         eigh (bool, optional): uses a more efficient way to calculate U and S. Defaults to True.
+         U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
+         L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
+         interval (int, optional): interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
+         concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it also whitens across parameters. Defaults to True.
+         inner (Chainable | None, optional): the preconditioner is applied to the output of this module. Defaults to None.
+
+     Examples:
+         Limited-memory Adagrad
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.LMAdagrad(),
+                 tz.m.LR(0.1)
+             )
+
+         Adam with L-Adagrad preconditioner (the second debiasing beta is set to 0.999 arbitrarily)
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.LMAdagrad(inner=tz.m.EMA()),
+                 tz.m.Debias(0.9, 0.999),
+                 tz.m.LR(0.01)
+             )
+
+         Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.LMAdagrad(inner=tz.m.EMA()),
+                 tz.m.Debias(0.9, 0.999),
+                 tz.m.ClipNormByEMA(max_ema_growth=1.2),
+                 tz.m.LR(0.01)
+             )
+
+     Reference:
+         Agarwal N. et al. Efficient full-matrix adaptive regularization. International Conference on Machine Learning, PMLR, 2019, pp. 102-110.
+     """
+
+     def __init__(
+         self,
+         history_size: int = 100,
+         update_freq: int = 1,
+         damping: float = 1e-4,
+         rdamping: float = 0,
+         order: int = 1,
+         true_damping: bool = True,
+         U_beta: float | None = None,
+         L_beta: float | None = None,
+         interval: int = 1,
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         # history is still updated each step, so Precondition's update_freq has a different meaning
+         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
+         super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)
+
+     @torch.no_grad
+     def update_tensor(self, tensor, param, grad, loss, state, setting):
+         order = setting['order']
+         history_size = setting['history_size']
+         update_freq = setting['update_freq']
+         damping = setting['damping']
+         rdamping = setting['rdamping']
+         U_beta = setting['U_beta']
+         L_beta = setting['L_beta']
+
+         if 'history' not in state: state['history'] = deque(maxlen=history_size)
+         history = state['history']
+
+         if order == 1:
+             t = tensor.clone().view(-1)
+             history.append(t)
+         else:
+
+             # if order=2, history is of gradient differences, order=3 is differences between differences, etc.,
+             # scaled by parameter differences
+             cur_p = param.clone()
+             cur_g = tensor.clone()
+             for i in range(1, order):
+                 if f'prev_g_{i}' not in state:
+                     state[f'prev_p_{i}'] = cur_p
+                     state[f'prev_g_{i}'] = cur_g
+                     break
+
+                 s = cur_p - state[f'prev_p_{i}']
+                 y = cur_g - state[f'prev_g_{i}']
+                 state[f'prev_p_{i}'] = cur_p
+                 state[f'prev_g_{i}'] = cur_g
+                 cur_p = s
+                 cur_g = y
+
+                 if i == order - 1:
+                     cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
+                     history.append(cur_g.view(-1))
+
+         step = state.get('step', 0)
+         if step % update_freq == 0 and len(history) != 0:
+             U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
+             maybe_lerp_(state, U_beta, 'U', U)
+             maybe_lerp_(state, L_beta, 'L', L)
+
+         if len(history) != 0:
+             state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         U = state.get('U', None)
+         if U is None:
+             # make a conservative step to avoid issues due to different GD scaling
+             return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
+
+         L = state['L']
+         update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)
+
+         return update
+
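The docstring above describes computing U and S from an SVD of the stacked gradient matrix M while the code actually eigendecomposes MᵀM. The following is a small standalone check in plain PyTorch (names chosen here for illustration; this is not torchzero code) that, ignoring damping and the small-eigenvalue cutoff, both routes give the same whitened update U S^-1 Uᵀ g:

    # standalone sketch: eigh-of-MᵀM route vs. the SVD formulation
    import torch

    torch.manual_seed(0)
    n, m = 50, 10                                   # parameter dimension, history size
    M = torch.randn(n, m, dtype=torch.float64)      # columns play the role of stored gradients
    g = torch.randn(n, dtype=torch.float64)

    # SVD route: U S^-1 Uᵀ g
    U, S, _ = torch.linalg.svd(M, full_matrices=False)
    update_svd = U @ torch.diag(1.0 / S) @ (U.T @ g)

    # eigh route (mirrors lm_adagrad_update / lm_adagrad_apply without damping):
    # MᵀM = Q L Qᵀ, so U = M Q L^-1/2 and S^2 = L
    L, Q = torch.linalg.eigh(M.T @ M)
    U2 = (M @ Q) * L.rsqrt()
    update_eigh = (U2 * L.rsqrt()) @ (U2.T @ g)

    print(torch.allclose(update_svd, update_eigh))  # True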
torchzero/modules/optimizers/lion.py
@@ -1,7 +1,7 @@
  import torch

  from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


  def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
@@ -28,8 +28,8 @@ class Lion(Transform):
          super().__init__(defaults, uses_grad=False)

      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         beta1, beta2 = self.get_settings('beta1', 'beta2', params = params, cls=NumberList)
-         exp_avg = self.get_state('ema', params=params, cls=TensorList)
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
+         exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
          return lion_(TensorList(tensors),exp_avg,beta1,beta2)

torchzero/modules/optimizers/mars.py
@@ -0,0 +1,91 @@
+ from operator import itemgetter
+ from functools import partial
+
+ import torch
+
+ from ...core import Module, Target, Transform, apply_transform, Chainable
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ..functional import (
+     debias, debiased_step_size,
+     ema_,
+     sqrt_ema_sq_,
+ )
+ from ..step_size.lr import lazy_lr
+ from ..momentum.experimental import sqrt_nag_ema_sq_
+ from ..momentum.momentum import nag_
+
+
+ def mars_correction_(
+     tensors_: TensorList,
+     prev_: TensorList,
+     beta: float | NumberList,
+     scaling: float | NumberList,
+     max_norm: float | NumberList | None,
+ ):
+     dg = (tensors_ - prev_).mul_(scaling * beta / (1-beta))
+     prev_.copy_(tensors_)
+
+     c = tensors_.add_(dg)
+     if max_norm is not None:
+         c.clip_norm_(max=max_norm, tensorwise=False)
+
+     return c
+
+ class MARSCorrection(Transform):
+     """MARS variance reduction correction.
+
+     Place any other momentum-based optimizer after this, and
+     make sure the :code:`beta` parameter matches the momentum in that optimizer.
+
+     Args:
+         beta (float, optional): use the same beta as in the momentum module. Defaults to 0.9.
+         scaling (float, optional): controls the scale of the gradient correction in variance reduction. Defaults to 0.025.
+         max_norm (float, optional): clips the norm of corrected gradients, None to disable. Defaults to 1.
+
+     Examples:
+         Mars-AdamW
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.MARSCorrection(beta=0.95),
+                 tz.m.Adam(beta1=0.95, beta2=0.99),
+                 tz.m.WeightDecay(1e-3),
+                 tz.m.LR(0.1)
+             )
+
+         Mars-Lion
+
+         .. code-block:: python
+
+             optimizer = tz.Modular(
+                 model.parameters(),
+                 tz.m.MARSCorrection(beta=0.9),
+                 tz.m.Lion(beta1=0.9),
+                 tz.m.LR(0.1)
+             )
+
+     """
+     def __init__(
+         self,
+         beta: float = 0.9,
+         scaling: float = 0.025,
+         max_norm: float | None = 1,
+     ):
+         defaults=dict(beta=beta, scaling=scaling, max_norm=max_norm)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+         beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
+         max_norm = settings[0]['max_norm']
+
+         return mars_correction_(
+             tensors_=TensorList(tensors),
+             prev_=prev,
+             beta=beta,
+             scaling=scaling,
+             max_norm=max_norm,
+         )
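mars_correction_ above implements the correction c_t = g_t + (scaling * beta / (1 - beta)) * (g_t - g_{t-1}), followed by an optional clip of the norm. A minimal single-tensor sketch of the same rule (the helper name and loop here are illustrative, not the torchzero API):

    import torch

    def mars_correct(grad, prev_grad, beta=0.95, scaling=0.025, max_norm=1.0):
        # c_t = g_t + scaling * beta / (1 - beta) * (g_t - g_{t-1}), then norm-clipped
        c = grad + (scaling * beta / (1.0 - beta)) * (grad - prev_grad)
        norm = c.norm()
        if max_norm is not None and norm > max_norm:
            c = c * (max_norm / norm)
        return c

    g_prev = torch.zeros(3)
    for step in range(3):
        g = torch.randn(3)              # stand-in for the current gradient
        c = mars_correct(g, g_prev)     # feed c (not g) to the momentum-based optimizer
        g_prev = g.clone()
        print(step, c)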
torchzero/modules/optimizers/msam.py
@@ -0,0 +1,186 @@
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, Transform, apply_transform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
+ from ..functional import ema_
+ from ..momentum.momentum import nag_
+
+
+ def msam_(
+     tensors: TensorList,
+     params: TensorList,
+     velocity_: TensorList,
+     momentum: float | NumberList,
+     lr: NumberList | None,
+     rho: float | NumberList,
+     weight_decay: float | NumberList,
+     nesterov: bool = False,
+     lerp: bool = False,
+
+     # inner args
+     inner: Module | None = None,
+     grads: list[torch.Tensor] | None = None,
+ ):
+     # weights w and wh, momentum μ, perturbation strength ρ
+     # w = wh + rho * v / ||v||
+     # v1 = μv + g
+     # w1 = w - lr*v1
+     # wh1 = w1 - rho * v1 / ||v1||
+
+     # w1 = wh + rho * v / ||v|| - lr*v1
+     # vn = rho * v / ||v||
+     # v1n = rho * v1 / ||v1||
+     # wh1 = wh + vn - lr*v1 - v1n
+
+     # the update is
+     # vn - lr*v1 - v1n
+
+     # we track the ascent direction, so it becomes lr*v1 + v1n - vn
+
+     # it can't really be decoupled from lr,
+     # but at least it is now expressed as a function of g
+
+     denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
+     vn = velocity_ / denom
+
+     mom_ = nag_ if nesterov else ema_
+     velocity_ = mom_(tensors, velocity_, momentum, dampening=0, lerp=lerp)
+
+     denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
+     v1n = velocity_ / denom
+
+     if inner is not None:
+         assert params is not None
+         inner_update = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+     else:
+         assert lr is not None
+         inner_update = velocity_ * lr
+
+     update = inner_update.add_(v1n).sub_(vn)
+
+     if generic_ne(weight_decay, 0):
+         wd = (params + vn).mul_(weight_decay)
+         update.add_(wd)
+
+     return update
+
+ class MSAM(Transform):
+     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+     This implementation expresses the update rule as a function of the gradient. This way it can be used as a drop-in
+     replacement for momentum strategies in other optimizers.
+
+     To combine MSAM with other optimizers in the way done in the official implementation,
+     e.g. to make Adam_MSAM, use the :code:`tz.m.MSAMObjective` module.
+
+     .. note::
+         MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
+         To avoid compounding learning rate modifications, remove the :code:`tz.m.LR` module if you had it.
+
+     Args:
+         lr (float): learning rate. Adding this module adds support for learning rate schedulers.
+         momentum (float, optional): momentum (beta). Defaults to 0.9.
+         rho (float, optional): perturbation strength. Defaults to 0.3.
+         weight_decay (float, optional):
+             weight decay. It is applied to the perturbed parameters, so it is different
+             from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
+         nesterov (bool, optional): whether to use the Nesterov momentum formula. Defaults to False.
+         lerp (bool, optional):
+             whether to use linear interpolation; if True, this becomes similar to an exponential moving average. Defaults to False.
+
+     Examples:
+         MSAM
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.MSAM(1e-3)
+             )
+
+         Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+         To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+                 tz.m.Debias(0.9, 0.999),
+             )
+     """
+     USES_LR = True
+     def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
+         defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
+         if self.USES_LR: defaults['lr'] = lr
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+         s = self.settings[params[0]]
+         lerp = s['lerp']
+         nesterov = s['nesterov']
+
+         if self.USES_LR:
+             lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
+
+         else:
+             lr=None
+             momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)
+
+         return msam_(
+             TensorList(tensors),
+             params=TensorList(params),
+             velocity_=velocity,
+             momentum=momentum,
+             lr=lr,
+             rho=rho,
+             weight_decay=weight_decay,
+             nesterov=nesterov,
+             lerp=lerp,
+
+             # inner args
+             inner=self.children.get("modules", None),
+             grads=grads,
+         )
+
+
+ class MSAMObjective(MSAM):
+     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+     .. note::
+         Please make sure to place :code:`tz.m.LR` inside the :code:`modules` argument. For example,
+         :code:`tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])`. Putting LR after MSAM will lead
+         to an incorrect update rule.
+
+     Args:
+         modules (Chainable): modules that will optimize the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
+         momentum (float, optional): momentum (beta). Defaults to 0.9.
+         rho (float, optional): perturbation strength. Defaults to 0.3.
+         nesterov (bool, optional): whether to use the Nesterov momentum formula. Defaults to False.
+         lerp (bool, optional):
+             whether to use linear interpolation; if True, MSAM momentum becomes similar to an exponential moving average.
+             Defaults to False.
+
+     Examples:
+         AdamW-MSAM
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 bench.parameters(),
+                 tz.m.MSAMObjective(
+                     [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+                     rho=1.
+                 )
+             )
+     """
+     USES_LR = False
+     def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
+         super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+         self.set_child('modules', modules)
+
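The comments inside msam_ derive how the perturb / momentum step / re-perturb sequence collapses into a single step on the unperturbed weights wh. A standalone check of that algebra (toy dimensions and names chosen here; weight decay and the inner module are omitted):

    import torch

    torch.manual_seed(0)
    lr, mu, rho = 0.1, 0.9, 0.3
    wh = torch.randn(5)
    v = torch.randn(5)

    for _ in range(3):
        g = torch.randn(5)                       # stand-in gradient at the perturbed point

        # literal MSAM: reconstruct w, update momentum, descent step, re-perturb
        vn = rho * v / v.norm().clamp(min=1e-8)
        w = wh + vn
        v1 = mu * v + g
        w1 = w - lr * v1
        v1n = rho * v1 / v1.norm().clamp(min=1e-8)
        wh1_literal = w1 - v1n

        # reformulated update used above: a single ascent direction applied to wh
        ascent = lr * v1 + v1n - vn
        wh1_reformulated = wh - ascent

        assert torch.allclose(wh1_literal, wh1_reformulated, atol=1e-6)
        wh, v = wh1_literal, v1

    print("reformulation matches literal MSAM")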
torchzero/modules/optimizers/muon.py
@@ -19,6 +19,7 @@ def _is_at_least_2d(p: torch.Tensor):

  # stolen from:
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
+ # actually at this stage it's a frankenstein
  @enable_compilation
  def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
      """
@@ -152,7 +153,7 @@ class Orthogonalize(TensorwiseTransform):
      The Muon page says that embeddings and classifier heads should not be orthogonalized.
      Usually only matrix parameters that are directly used in matmuls should be orthogonalized.

-     To make Muon, use Split with Adam on 1d params: TODO code example.
+     To make Muon, use Split with Adam on 1d params.

      Args:
          ns_steps (int, optional):
@@ -164,7 +165,31 @@ class Orthogonalize(TensorwiseTransform):
          method (str, optional):
              Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
          target (str, optional):
-             what to set on vars.
+             what to set on var.
+
+
+     Examples:
+         standard Muon with Adam fallback
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.head.parameters(),
+                 tz.m.Split(
+                     # apply muon only to 2D+ parameters
+                     filter = lambda t: t.ndim >= 2,
+                     true = [
+                         tz.m.HeavyBall(),
+                         tz.m.Orthogonalize(),
+                         tz.m.LR(1e-2),
+                     ],
+                     false = tz.m.Adam()
+                 ),
+                 tz.m.LR(1e-2)
+             )
+
+     Reference:
+         Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
      """
      def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
                   method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
@@ -172,9 +197,9 @@ class Orthogonalize(TensorwiseTransform):
          super().__init__(uses_grad=False, defaults=defaults, target=target)

      @torch.no_grad
-     def transform(self, tensor, param, grad, vars):
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
          orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
-             'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(self.settings[param])
+             'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)

          if not orthogonalize: return tensor

@@ -199,7 +224,7 @@ class DualNormCorrection(TensorwiseTransform):
      def __init__(self, target: Target='update'):
          super().__init__({}, uses_grad=True, target=target)

-     def transform(self, tensor, param, grad, vars):
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
          assert grad is not None
          if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
              return _dual_norm_correction(tensor, grad, batch_first=False)
@@ -213,8 +238,8 @@ class MuonAdjustLR(Transform):
          defaults = dict(alpha=alpha)
          super().__init__(defaults=defaults, uses_grad=False, target=target)

-     def transform(self, tensors, params, grads, vars):
-         alphas = self.get_settings('alpha', params=params)
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         alphas = [s['alpha'] for s in settings]
          tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
          tensors = [i[0] for i in tensors_alphas]
          a = [i[1] for i in alphas]
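zeropower_via_newtonschulz5 referenced above orthogonalizes a matrix with a tuned quintic Newton-Schulz iteration; the docstring notes SVD as the slower, slightly more precise alternative. A minimal sketch of the underlying idea using the classic cubic Newton-Schulz iteration (a simplified stand-in, not Muon's quintic with tuned coefficients): the iteration drives all singular values toward 1, so the result converges to the same orthogonal factor U @ Vh that SVD-based orthogonalization returns.

    import torch

    torch.manual_seed(0)
    G = torch.randn(5, 3, dtype=torch.float64)

    X = G / G.norm()                      # scale so singular values lie in (0, 1]
    for _ in range(30):
        X = 1.5 * X - 0.5 * X @ X.T @ X   # cubic Newton-Schulz step

    U, S, Vh = torch.linalg.svd(G, full_matrices=False)
    print(torch.allclose(X, U @ Vh, atol=1e-6))  # True: both orthogonalize G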
torchzero/modules/optimizers/orthograd.py
@@ -30,16 +30,15 @@ class OrthoGrad(Transform):
      Args:
          eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
          renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
-         target (Target, optional): what to set on vars. Defaults to 'update'.
+         target (Target, optional): what to set on var. Defaults to 'update'.
      """
      def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
          defaults = dict(eps=eps, renormalize=renormalize)
          super().__init__(defaults, uses_grad=False, target=target)

-     def transform(self, tensors, params, grads, vars):
-         settings = self.settings[params[0]]
-         eps = settings['eps']
-         renormalize = settings['renormalize']
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         eps = settings[0]['eps']
+         renormalize = settings[0]['renormalize']

          params = as_tensorlist(params)
          target = as_tensorlist(tensors)
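The OrthoGrad docstring above describes removing the gradient component parallel to the weights, with eps in the denominator and optional renormalization back to the original gradient norm. A rough single-tensor sketch of that projection (an illustration of the described rule under those assumptions, not the torchzero implementation, which operates on tensor lists):

    import torch

    def orthograd(w: torch.Tensor, g: torch.Tensor, eps: float = 1e-30, renormalize: bool = True):
        w_flat, g_flat = w.reshape(-1), g.reshape(-1)
        proj = (w_flat @ g_flat) / (w_flat @ w_flat + eps)     # scalar projection coefficient
        g_orth = g_flat - proj * w_flat                        # component orthogonal to w
        if renormalize:
            g_orth = g_orth * (g_flat.norm() / g_orth.norm().clamp(min=eps))
        return g_orth.reshape_as(g)

    w = torch.randn(4, 3)
    g = torch.randn(4, 3)
    g_o = orthograd(w, g)
    print(torch.dot(w.reshape(-1), g_o.reshape(-1)))  # ~0: orthogonal to the weights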
torchzero/modules/optimizers/rmsprop.py
@@ -3,8 +3,8 @@ from typing import Literal

  import torch

- from ...core import Module, Target, Transform, Chainable, Vars, apply
- from ...utils import NumberList, TensorList
+ from ...core import Module, Target, Transform, Chainable, Var, apply_transform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
  from ..functional import sqrt_centered_ema_sq_, sqrt_ema_sq_


@@ -23,7 +23,6 @@ def rmsprop_(
      inner: Module | None = None,
      params: list[torch.Tensor] | None = None,
      grads: list[torch.Tensor] | None = None,
-     vars: Vars | None = None,
  ):
      """returns `tensors_`"""
      if exp_avg_ is not None:
@@ -36,12 +35,14 @@ def rmsprop_(

      if inner is not None:
          assert params is not None
-         tensors_ = TensorList(apply(inner, tensors_, params=params, grads=grads, vars=vars))
+         tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))

      return tensors_.div_(sqrt_exp_avg_sq.add_(eps))

  class RMSprop(Transform):
-     """Divides gradient by EMA of gradient squares. Matches pytorch RMSprop if "init" is set to "zeros".
+     """Divides gradient by an EMA of gradient squares.
+
+     This implementation is identical to :code:`torch.optim.RMSprop`.

      Args:
          smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
@@ -51,7 +52,8 @@ class RMSprop(Transform):
          amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
          pow (float, optional): power used in second momentum power and root. Defaults to 2.
          init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
-         inner (Chainable | None, optional): Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
+         inner (Chainable | None, optional):
+             Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
      """
      def __init__(
          self,
@@ -61,26 +63,25 @@ class RMSprop(Transform):
          debiased: bool = False,
          amsgrad: bool = False,
          pow: float = 2,
-         init: Literal["zeros", "update"] = "update",
+         init: Literal["zeros", "update"] = "zeros",
          inner: Chainable | None = None,
      ):
          defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
          super().__init__(defaults=defaults, uses_grad=False)
-         self.current_step = 0
+
          if inner is not None:
              self.set_child('inner', inner)

-     def transform(self, tensors, params, grads, vars):
-         self.current_step += 1
-
-         smoothing,eps = self.get_settings('smoothing', 'eps', params=params, cls=NumberList)
-         centered,debiased,amsgrad,pow,init = itemgetter('centered','debiased','amsgrad','pow','init')(self.settings[params[0]])
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+         smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
+         centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])

-         exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
-         exp_avg = self.get_state('exp_avg', params=params, cls=TensorList) if centered else None
-         max_exp_avg_sq = self.get_state('max_exp_avg_sq', params=params, cls=TensorList) if amsgrad else None
+         exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
+         exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList) if centered else None
+         max_exp_avg_sq = unpack_states(states, tensors, 'max_exp_avg_sq', cls=TensorList) if amsgrad else None

-         if init == 'update' and self.current_step == 1:
+         if init == 'update' and step == 1:
              exp_avg_sq.set_([t**2 for t in tensors])
              if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])

@@ -90,7 +91,7 @@ class RMSprop(Transform):
              smoothing=smoothing,
              eps=eps,
              debiased=debiased,
-             step=self.current_step,
+             step=step,
              exp_avg_=exp_avg,
              max_exp_avg_sq_=max_exp_avg_sq,
              pow=pow,
@@ -99,5 +100,4 @@ class RMSprop(Transform):
              inner=self.children.get("inner", None),
              params=params,
              grads=grads,
-             vars=vars,
          )
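The new default init="zeros" is what lines the module up with :code:`torch.optim.RMSprop`, which also starts its squared-gradient EMA at zero. A standalone check of that zero-initialized update rule against torch.optim.RMSprop (toy tensors; this exercises the formula, not the torchzero module itself):

    import torch

    torch.manual_seed(0)
    p_ref = torch.randn(4, requires_grad=True)
    p_man = p_ref.detach().clone()
    opt = torch.optim.RMSprop([p_ref], lr=0.01, alpha=0.99, eps=1e-8)

    sq_avg = torch.zeros_like(p_man)                     # init="zeros"
    for _ in range(5):
        g = torch.randn(4)

        p_ref.grad = g.clone()
        opt.step()

        sq_avg = 0.99 * sq_avg + 0.01 * g * g            # EMA of squared gradients
        p_man = p_man - 0.01 * g / (sq_avg.sqrt() + 1e-8)

        assert torch.allclose(p_ref.detach(), p_man, atol=1e-5)

    print("hand-rolled zero-init RMSprop matches torch.optim.RMSprop")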