torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/msam.py +71 -53

@@ -2,9 +2,9 @@ from typing import Literal
 
  import torch
 
- from ...core import Chainable, Module, Target, Transform, apply_transform
+ from ...core import Chainable, Module, Transform, TensorTransform, step, Objective
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
- from ..functional import ema_
+ from ..opt_utils import ema_
  from ..momentum.momentum import nag_
 
 
@@ -21,7 +21,7 @@ def msam_(
 
  # inner args
  inner: Module | None = None,
- grads: list[torch.Tensor] | None = None,
+ objective: Objective | None = None,
  ):
  # weights w and wh, momentum μ, perturbation strength ρ
  # w = wh + rho * v / ||v||
@@ -54,8 +54,8 @@ def msam_(
  v1n = velocity_ / denom
 
  if inner is not None:
- assert params is not None
- inner_update = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+ assert objective is not None and inner is not None
+ inner_update = TensorList(step(objective, inner).get_updates())
 
  else:
  assert lr is not None
@@ -69,7 +69,7 @@ def msam_(
 
  return update
 
- class MSAM(Transform):
+ class MSAMMomentum(TensorTransform):
  """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
 
  This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
@@ -93,46 +93,40 @@ class MSAM(Transform):
  lerp (bool, optional):
  whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
 
- Examples:
- MSAM
+ ### Examples:
 
- .. code-block:: python
+ MSAM
 
- opt = tz.Modular(
- model.parameters(),
- tz.m.MSAM(1e-3)
- )
+ ```python
 
- Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
- To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.MSAM(1e-3)
+ )
+ ```
 
- .. code-block:: python
+ Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+ To make Adam_MSAM and such, use the ``tz.m.MSAMObjective`` module.
 
- opt = tz.Modular(
- model.parameters(),
- tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
- tz.m.Debias(0.9, 0.999),
- )
+ ```python
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+ tz.m.Debias(0.9, 0.999),
+ )
+ ```
  """
- _USES_LR = True
+
  def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
- defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
- if self._USES_LR: defaults['lr'] = lr
+ defaults = dict(lr = lr, momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
  super().__init__(defaults, uses_grad=False)
 
  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
- s = self.settings[params[0]]
- lerp = s['lerp']
- nesterov = s['nesterov']
+ fs = settings[0]
 
- if self._USES_LR:
- lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
-
- else:
- lr=None
- momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)
+ lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
 
  return msam_(
  TensorList(tensors),
@@ -142,16 +136,16 @@ class MSAM(Transform):
  lr=lr,
  rho=rho,
  weight_decay=weight_decay,
- nesterov=nesterov,
- lerp=lerp,
+ nesterov=fs['nesterov'],
+ lerp=fs['lerp'],
 
  # inner args
- inner=self.children.get("modules", None),
- grads=grads,
+ inner=None,
+ objective=None,
  )
 
 
- class MSAMObjective(MSAM):
+ class MSAM(Transform):
  """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
 
  Note:
@@ -160,7 +154,7 @@ class MSAMObjective(MSAM):
  to an incorrect update rule.
 
  Args:
- modules (Chainable): modules that will optimizer the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
+ modules (Chainable): modules that will optimize the MSAM objective. Make sure ``tz.m.LR`` is one of them.
  momentum (float, optional): momentum (beta). Defaults to 0.9.
  rho (float, optional): perturbation strength. Defaults to 0.3.
  nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
@@ -169,20 +163,44 @@ class MSAMObjective(MSAM):
  Defaults to False.
 
  Examples:
- AdamW-MSAM
-
- .. code-block:: python
-
- opt = tz.Modular(
- bench.parameters(),
- tz.m.MSAMObjective(
- [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
- rho=1.
- )
- )
+ AdamW-MSAM
+
+ ```py
+ opt = tz.Optimizer(
+ bench.parameters(),
+ tz.m.MSAMObjective(
+ [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+ rho=1.
+ )
+ )
+ ```
  """
- _USES_LR = False
  def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
- super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+ defaults = dict(momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+ super().__init__(defaults)
+
  self.set_child('modules', modules)
 
+
+ @torch.no_grad
+ def apply_states(self, objective, states, settings):
+ velocity = unpack_states(states, objective.params, 'velocity', cls=TensorList)
+ fs = settings[0]
+
+ momentum, rho, weight_decay = unpack_dicts(settings, 'momentum', 'rho', 'weight_decay', cls=NumberList)
+
+ return msam_(
+ TensorList(objective.get_updates()),
+ params=TensorList(objective.params),
+ velocity_=velocity,
+ momentum=momentum,
+ lr=None,
+ rho=rho,
+ weight_decay=weight_decay,
+ nesterov=fs['nesterov'],
+ lerp=fs['lerp'],
+
+ # inner args
+ inner=self.children["modules"],
+ objective=objective,
+ )
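
The `msam_` helper above documents its update only through the inline comments ("weights w and wh, momentum μ, perturbation strength ρ", "w = wh + rho * v / ||v||"). As a reading aid, here is a minimal single-tensor sketch of that rule as those comments describe it; this is not torchzero code, the function name is hypothetical, and weight decay, `nesterov`, `lerp` and the `inner` module are omitted.

```python
import torch

def msam_step_sketch(w: torch.Tensor, grad: torch.Tensor, v: torch.Tensor,
                     lr: float = 1e-3, momentum: float = 0.9, rho: float = 0.3):
    """One hypothetical Momentum-SAM style step on a single tensor.

    `w` holds the perturbed weights, related to the unperturbed weights `wh`
    by w = wh + rho * v / ||v|| (the relation stated in the msam_ comments).
    """
    eps = 1e-12
    # recover the unperturbed weights wh from the current perturbed weights w
    wh = w - rho * v / v.norm().clamp_min(eps)
    # heavy-ball momentum on the gradient evaluated at the perturbed weights
    v.mul_(momentum).add_(grad)
    # descend on wh, then re-apply the perturbation along the new momentum
    wh = wh - lr * v
    w.copy_(wh + rho * v / v.norm().clamp_min(eps))
    return w, v
```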
torchzero/modules/adaptive/muon.py +67 -129

@@ -1,152 +1,85 @@
  from operator import itemgetter
  import math
- import warnings
- from collections.abc import Iterable, Sequence
- from typing import Literal
+ from collections.abc import Iterable
 
  import torch
 
- from ...core import Modular, TensorwiseTransform, Target, Transform
- from ...utils import enable_compilation
-
+ from ...core import TensorTransform, Transform
+ from ...linalg.orthogonalize import orthogonalize as _orthogonalize, OrthogonalizeMethod
 
  def reverse_dims(t:torch.Tensor):
  return t.permute(*reversed(range(t.ndim)))
 
- def _is_at_least_2d(p: torch.Tensor):
- if (p.ndim >= 2) and (p.size(0) > 1) and (p.size(1) > 1): return True
+ def _is_at_least_2d(p: torch.Tensor, channel_first:bool):
+ if p.ndim < 2: return False
+ if channel_first and (p.size(0) > 1) and (p.size(1) > 1): return True
+ if (not channel_first) and (p.size(-2) > 1) and (p.size(-1) > 1): return True
  return False
 
- # stolen from:
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
- # actually at this stage its a frankenstein
- @enable_compilation
- def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
- """
- Applies to last 2 dims - so usually reverse_dims should be applied to G before and after.
-
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
- performance at all relative to UV^T, where USV^T = G is the SVD.
- """
- assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
- a, b, c = (3.4445, -4.7750, 2.0315)
- X = G.bfloat16()
- if G.size(-2) > G.size(-1):
- X = X.mT
-
- # Ensure spectral norm is at most 1
- X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
- # Perform the NS iterations
- for _ in range(steps):
- A = X @ X.mT
- B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
- X = a * X + B @ X
-
- if G.size(-2) > G.size(-1):
- X = X.mT
- return X
-
- # stolen from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
- # Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
- # Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
- @torch.no_grad
- def _svd_orthogonalize(G: torch.Tensor, warn_fail=True) -> torch.Tensor:
- """
- Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
- """
- X = G.view(G.shape[0], -1)
-
- t = False
- if X.size(0) > X.size(1):
- X = X.T
- t = True
-
- orth_X: torch.Tensor | None = None
- try:
- u, s, vt = torch.linalg.svd(X, full_matrices=False) # pylint:disable=not-callable
- orth_X = u @ vt
- except RuntimeError:
- # if warn: logging.warning('Failed to perform SVD, adding some noise.')
- try:
- u, s, v = torch.svd_lowrank(
- X,
- q=1, # assume rank is at least 1
- M=1e-4 * X.mean() * torch.randn_like(X))
- orth_X = u @ v.T
- except RuntimeError:
- if warn_fail: warnings.warn(('Failed to perform SVD with noise,'
- ' skipping gradient orthogonalisation'))
- if orth_X is not None:
- if t: orth_X = orth_X.T
- return orth_X.view_as(G)
-
- return G # fail
+ def _orthogonalize_format(
+ tensor: torch.Tensor,
+ method: OrthogonalizeMethod,
+ channel_first: bool,
+ ):
+ """orthogonalize either 1st two dims if channel first or last two otherwise"""
+ if channel_first:
+ return reverse_dims(_orthogonalize(reverse_dims(tensor), method=method))
 
+ return _orthogonalize(tensor, method=method)
 
  @torch.no_grad
- def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor, batch_first):
- """batch first means it applies to last 2 dims, otherwise to 1st two dims"""
+ def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor, channel_first: bool):
+ """``channel_first`` means it applies to first two dims, otherwise to last two dims"""
  # this is from https://github.com/leloykun/adaptive-muon
  # Adaptive scaling,`(G * X).sum() * X` == (G.T @ X).trace() * X
- if batch_first: X = torch.einsum('...ij,...ij,...ab->...ab', g.type_as(X), X, X)
- else: X = torch.einsum('ij...,ij...,ab...->ab...', g.type_as(X), X, X)
+ if channel_first: X = torch.einsum('ij...,ij...,ab...->ab...', g.type_as(X), X, X)
+ else: X = torch.einsum('...ij,...ij,...ab->...ab', g.type_as(X), X, X)
  return X
 
 
  # code from
  # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
- def adjust_lr_for_muon(lr, param_shape):
- A, B = param_shape[:2]
+ def adjust_lr_for_muon(lr, param_shape, channel_first:bool):
+ if channel_first: A, B = param_shape[:2]
+ else: A, B = param_shape[-2:]
+
  # We adjust the learning rate and weight decay based on the size of the parameter matrix
  # as describted in the paper
  adjusted_ratio = 0.2 * math.sqrt(max(A, B))
  adjusted_lr = lr * adjusted_ratio
  return adjusted_lr
 
- def _orthogonalize_tensor(
- tensor: torch.Tensor,
- steps: int = 5,
- method: Literal["newton-schulz", "svd"] = "newton-schulz",
- ):
- if method == 'newton-schulz': return reverse_dims(zeropower_via_newtonschulz5(reverse_dims(tensor), steps)).type_as(tensor)
- if method == 'svd': return _svd_orthogonalize(tensor, False)
- raise ValueError(method)
-
 
  def orthogonalize_grads_(
  params: Iterable[torch.Tensor],
- steps: int = 5,
  dual_norm_correction=False,
- method: Literal["newton-schulz", "svd"] = "newton-schulz",
+ method: OrthogonalizeMethod = "newtonschulz",
+ channel_first:bool=True,
  ):
- """Uses newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.
+ """Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.
 
  This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).
 
  Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.
  Args:
  params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
- steps (int, optional):
- The number of Newton-Schulz iterations to run. Defaults to 5.
  dual_norm_correction (bool, optional):
  enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
  method (str, optional):
  Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
+ channel_first (bool, optional):
+ if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
+ are considered batch dimensions.
  """
  for p in params:
- if (p.grad is not None) and _is_at_least_2d(p.grad):
- X = _orthogonalize_tensor(p.grad, steps, method)
- if dual_norm_correction: X = _dual_norm_correction(X, p.grad, batch_first=False)
+ if (p.grad is not None) and _is_at_least_2d(p.grad, channel_first=channel_first):
+ X = _orthogonalize_format(p.grad, method=method, channel_first=channel_first)
+ if dual_norm_correction: X = _dual_norm_correction(X, p.grad, channel_first=False)
  p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]
 
 
 
- class Orthogonalize(TensorwiseTransform):
+ class Orthogonalize(TensorTransform):
  """Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.
 
  To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
@@ -156,22 +89,21 @@ class Orthogonalize(TensorwiseTransform):
  To make Muon, use Split with Adam on 1d params
 
  Args:
- ns_steps (int, optional):
- The number of Newton-Schulz iterations to run. Defaults to 5.
  adjust_lr (bool, optional):
  Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
  dual_norm_correction (bool, optional):
  enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
  method (str, optional):
- Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
- target (str, optional):
- what to set on var.
+ Newton-Schulz is very fast, SVD is slow but can be more precise.
+ channel_first (bool, optional):
+ if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
+ are considered batch dimensions.
 
  ## Examples:
 
  standard Muon with Adam fallback
  ```py
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.head.parameters(),
  tz.m.Split(
  # apply muon only to 2D+ parameters
@@ -190,56 +122,62 @@ class Orthogonalize(TensorwiseTransform):
  Reference:
  Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
  """
- def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
- method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
- defaults = dict(orthogonalize=True, ns_steps=ns_steps, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower())
- super().__init__(uses_grad=False, defaults=defaults, target=target)
+ def __init__(self, adjust_lr=False, dual_norm_correction=False,
+ method: OrthogonalizeMethod = 'newtonschulz', channel_first:bool=True):
+ defaults = dict(orthogonalize=True, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower(), channel_first=channel_first)
+ super().__init__(defaults=defaults)
 
  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
- orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
- 'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+ orthogonalize, dual_norm_correction, adjust_lr, method, channel_first = itemgetter(
+ 'orthogonalize', 'dual_norm_correction', 'adjust_lr', 'method', 'channel_first')(setting)
 
  if not orthogonalize: return tensor
 
- if _is_at_least_2d(tensor):
+ if _is_at_least_2d(tensor, channel_first=channel_first):
 
- X = _orthogonalize_tensor(tensor, ns_steps, method)
+ X = _orthogonalize_format(tensor, method, channel_first=channel_first)
 
  if dual_norm_correction:
- X = _dual_norm_correction(X, tensor, batch_first=False)
+ X = _dual_norm_correction(X, tensor, channel_first=channel_first)
 
  if adjust_lr:
- X.mul_(adjust_lr_for_muon(1, param.shape))
+ X.mul_(adjust_lr_for_muon(1, param.shape, channel_first=channel_first))
 
  return X.view_as(param)
 
  return tensor
 
 
- class DualNormCorrection(TensorwiseTransform):
+ class DualNormCorrection(TensorTransform):
  """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
  Orthogonalize already has this built in with the `dual_norm_correction` setting."""
- def __init__(self, target: Target='update'):
- super().__init__({}, uses_grad=True, target=target)
+ def __init__(self, channel_first: bool = True):
+ defaults = dict(channel_first=channel_first)
+ super().__init__(defaults)
 
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ @torch.no_grad
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  assert grad is not None
  if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
- return _dual_norm_correction(tensor, grad, batch_first=False)
+ return _dual_norm_correction(tensor, grad, channel_first=setting["channel_first"])
  return tensor
 
 
  class MuonAdjustLR(Transform):
  """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
- Orthogonalize already has this built in with the `adjust_lr` setting, however you might want to move this to be later in the chain."""
- def __init__(self, alpha: float = 1, target: Target='update'):
- defaults = dict(alpha=alpha)
- super().__init__(defaults=defaults, uses_grad=False, target=target)
+ Orthogonalize already has this built in with the ``adjust_lr`` setting, however you might want to move this to be later in the chain."""
+ def __init__(self, channel_first: bool = True, alpha: float = 1):
+ defaults = dict(channel_first=channel_first, alpha=alpha)
+ super().__init__(defaults=defaults)
 
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ @torch.no_grad
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  alphas = [s['alpha'] for s in settings]
- tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
+ channel_first = [s["channel_first=channel_first"] for s in settings]
+ tensors_alphas = [
+ (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t, channel_first=cf)
+ ]
  tensors = [i[0] for i in tensors_alphas]
  a = [i[1] for i in alphas]
  torch._foreach_mul_(tensors, a)
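
The hunks above drop the local Newton-Schulz and SVD routines in favour of `torchzero.linalg.orthogonalize` and thread a new `channel_first` flag through every call site; per the new docstrings it selects whether the first two or the last two dimensions are orthogonalized, with the remaining dimensions treated as batch dimensions. Below is a rough sketch of that dimension handling only, with a plain SVD orthogonalizer standing in for the library routine (whose internals this diff does not show):

```python
import torch

def orthogonalize_last2(x: torch.Tensor) -> torch.Tensor:
    """Replace the matrix formed by the last two dims (batched) with U @ Vh."""
    u, _, vh = torch.linalg.svd(x, full_matrices=False)
    return u @ vh

def orthogonalize_with_format(x: torch.Tensor, channel_first: bool = True) -> torch.Tensor:
    if channel_first:
        # first two dims are (out_channels, in_channels), the rest are batch dims:
        # reverse the dim order, orthogonalize the (now trailing) two dims, reverse back
        rev = x.permute(*reversed(range(x.ndim)))
        return orthogonalize_last2(rev).permute(*reversed(range(x.ndim)))
    # channel_first=False: the last two dims form the matrix, leading dims are batch
    return orthogonalize_last2(x)

w = torch.randn(16, 8, 3, 3)  # e.g. a conv weight (out_channels, in_channels, kH, kW)
print(orthogonalize_with_format(w, channel_first=True).shape)  # torch.Size([16, 8, 3, 3])
```

With `channel_first=True` the trailing kernel dimensions act as the batch, mirroring the `reverse_dims` round-trip in `_orthogonalize_format`; with `channel_first=False` the leading dimensions are the batch, as in the deleted `zeropower_via_newtonschulz5`.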