torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/matrix_momentum.py

@@ -1,14 +1,13 @@
  from typing import Literal
- from collections.abc import Callable
+
  import torch

- from ...core import Module, apply_transform, Chainable
- from ...utils import NumberList, TensorList, as_tensorlist
- from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+ from ...core import Chainable, Transform, HVPMethod
+ from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
  from ..functional import initial_step_size


- class MatrixMomentum(Module):
+ class MatrixMomentum(Transform):
  """Second order momentum method.

  Matrix momentum is useful for convex objectives, also for some reason it has very really good generalization on elastic net logistic regression.
@@ -23,17 +22,17 @@ class MatrixMomentum(Module):
  Args:
  mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
  hvp_method (str, optional):
- Determines how Hessian-vector products are evaluated.
-
- - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
- This requires creating a graph for the gradient.
- - ``"forward"``: Use a forward finite difference formula to
- approximate the HVP. This requires one extra gradient evaluation.
- - ``"central"``: Use a central finite difference formula for a
- more accurate HVP approximation. This requires two extra
- gradient evaluations.
- Defaults to "autograd".
- h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+ Determines how hessian-vector products are computed.
+
+ - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+ - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+ - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+ - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+ Defaults to ``"autograd"``.
+ h (float, optional):
+ The step size for finite difference if ``hvp_method`` is
+ ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
  hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.

  Reference:
@@ -44,51 +43,45 @@ class MatrixMomentum(Module):
  self,
  lr:float,
  mu=0.1,
- hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+ hvp_method: HVPMethod = "autograd",
  h: float = 1e-3,
  adaptive:bool = False,
  adapt_freq: int | None = None,
- hvp_tfm: Chainable | None = None,
+
+ inner: Chainable | None = None,
  ):
  defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
- super().__init__(defaults)
-
- if hvp_tfm is not None:
- self.set_child('hvp_tfm', hvp_tfm)
+ super().__init__(defaults, inner=inner)

  def reset_for_online(self):
  super().reset_for_online()
  self.clear_state_keys('p_prev')

  @torch.no_grad
- def update(self, var):
- assert var.closure is not None
- p = TensorList(var.params)
- p_prev = self.get_state(p, 'p_prev', init=var.params)
+ def update_states(self, objective, states, settings):
+ step = self.increment_counter("step", 0)
+ p = TensorList(objective.params)
+ p_prev = unpack_states(states, p, 'p_prev', init=p)

- hvp_method = self.defaults['hvp_method']
- h = self.defaults['h']
- step = self.global_state.get("step", 0)
- self.global_state["step"] = step + 1
+ fs = settings[0]
+ hvp_method = fs['hvp_method']
+ h = fs['h']

  if step > 0:
  s = p - p_prev

- Hs, _ = var.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_graph=False)
+ Hs, _ = objective.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, retain_graph=False)
  Hs = [t.detach() for t in Hs]

- if 'hvp_tfm' in self.children:
- Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))
-
  self.store(p, ("Hs", "s"), (Hs, s))

  # -------------------------------- adaptive mu ------------------------------- #
- if self.defaults["adaptive"]:
- g = TensorList(var.get_grad())
+ if fs["adaptive"]:
+ g = TensorList(objective.get_grads())

- if self.defaults["adapt_freq"] is None:
+ if fs["adapt_freq"] is None:
  # ---------------------------- deterministic case ---------------------------- #
- g_prev = self.get_state(var.params, "g_prev", cls=TensorList)
+ g_prev = unpack_states(states, p, "g_prev", cls=TensorList)
  y = g - g_prev
  g_prev.copy_(g)
  denom = y.global_vector_norm()
@@ -101,14 +94,14 @@ class MatrixMomentum(Module):

  # we start on 1nd step, and want to adapt when we start, so use (step - 1)
  if (step - 1) % adapt_freq == 0:
- assert var.closure is not None
- params = TensorList(var.params)
+ assert objective.closure is not None
+ params = TensorList(objective.params)
  p_cur = params.clone()

  # move to previous params and evaluate p_prev with current mini-batch
- params.copy_(self.get_state(var.params, 'p_prev'))
+ params.copy_(unpack_states(states, p, 'p_prev'))
  with torch.enable_grad():
- var.closure()
+ objective.closure()
  g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
  y = g - g_prev

@@ -119,12 +112,12 @@ class MatrixMomentum(Module):
  denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
  self.global_state["mu_mul"] = s.global_vector_norm() / denom

- torch._foreach_copy_(p_prev, var.params)
+ torch._foreach_copy_(p_prev, objective.params)

  @torch.no_grad
- def apply(self, var):
- update = TensorList(var.get_update())
- lr,mu = self.get_settings(var.params, "lr", 'mu', cls=NumberList)
+ def apply_states(self, objective, states, settings):
+ update = TensorList(objective.get_updates())
+ lr, mu = unpack_dicts(settings, "lr", 'mu', cls=NumberList)

  if "mu_mul" in self.global_state:
  mu = mu * self.global_state["mu_mul"]
@@ -133,14 +126,17 @@ class MatrixMomentum(Module):
  # p_prev is not available so make a small step
  step = self.global_state["step"]
  if step == 1:
- if self.defaults["adaptive"]: self.get_state(var.params, "g_prev", init=var.get_grad())
+ if self.defaults["adaptive"]:
+ # initialize
+ unpack_states(states, objective.params, "g_prev", init=objective.get_grads())
+
  update.mul_(lr) # separate so that initial_step_size can clip correctly
  update.mul_(initial_step_size(update, 1e-7))
- return var
+ return objective

  # -------------------------- matrix momentum update -------------------------- #
- s, Hs = self.get_state(var.params, 's', 'Hs', cls=TensorList)
+ s, Hs = unpack_states(states, objective.params, 's', 'Hs', cls=TensorList)

  update.mul_(lr).sub_(s).add_(Hs*mu)
- var.update = update
- return var
+ objective.updates = update
+ return objective
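For readers tracking the API change above: in 0.4.0 `MatrixMomentum` is a `Transform` whose optional child is passed as `inner` rather than `hvp_tfm`, and the finite-difference HVP modes are renamed to `"fd_forward"`/`"fd_central"`. A minimal usage sketch, assuming the module is exported as `tz.m.MatrixMomentum` and that the modular optimizer accepts the usual torchzero closure with a `backward` flag (both assumptions, not shown in this diff):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    # keyword names follow the 0.4.0 __init__ shown in the hunk above
    tz.m.MatrixMomentum(lr=1e-2, mu=0.1, hvp_method="fd_central", h=1e-3),
)

def closure(backward=True):
    # hessian-vector products need a re-evaluatable closure
    loss = model(torch.randn(8, 4)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```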
torchzero/modules/adaptive/msam.py

@@ -2,7 +2,7 @@ from typing import Literal

  import torch

- from ...core import Chainable, Module, Target, Transform, apply_transform
+ from ...core import Chainable, Module, Transform, TensorTransform, step, Objective
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
  from ..functional import ema_
  from ..momentum.momentum import nag_
@@ -21,7 +21,7 @@ def msam_(

  # inner args
  inner: Module | None = None,
- grads: list[torch.Tensor] | None = None,
+ objective: Objective | None = None,
  ):
  # weights w and wh, momentum μ, perturbation strength ρ
  # w = wh + rho * v / ||v||
@@ -54,8 +54,8 @@
  v1n = velocity_ / denom

  if inner is not None:
- assert params is not None
- inner_update = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+ assert objective is not None and inner is not None
+ inner_update = TensorList(step(objective, inner).get_updates())

  else:
  assert lr is not None
@@ -69,7 +69,7 @@

  return update

- class MSAM(Transform):
+ class MSAMMomentum(TensorTransform):
  """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

  This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
@@ -93,46 +93,40 @@ class MSAM(Transform):
  lerp (bool, optional):
  whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.

- Examples:
- MSAM
+ ### Examples:

- .. code-block:: python
+ MSAM

- opt = tz.Modular(
- model.parameters(),
- tz.m.MSAM(1e-3)
- )
+ ```python

- Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
- To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.MSAM(1e-3)
+ )
+ ```

- .. code-block:: python
+ Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+ To make Adam_MSAM and such, use the ``tz.m.MSAMObjective`` module.

- opt = tz.Modular(
- model.parameters(),
- tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
- tz.m.Debias(0.9, 0.999),
- )
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+ tz.m.Debias(0.9, 0.999),
+ )
+ ```
  """
- _USES_LR = True
+
  def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
- defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
- if self._USES_LR: defaults['lr'] = lr
+ defaults = dict(lr = lr, momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
- s = self.settings[params[0]]
- lerp = s['lerp']
- nesterov = s['nesterov']
+ fs = settings[0]

- if self._USES_LR:
- lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
-
- else:
- lr=None
- momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)
+ lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)

  return msam_(
  TensorList(tensors),
@@ -142,16 +136,16 @@ class MSAM(Transform):
  lr=lr,
  rho=rho,
  weight_decay=weight_decay,
- nesterov=nesterov,
- lerp=lerp,
+ nesterov=fs['nesterov'],
+ lerp=fs['lerp'],

  # inner args
- inner=self.children.get("modules", None),
- grads=grads,
+ inner=None,
+ objective=None,
  )


- class MSAMObjective(MSAM):
+ class MSAM(Transform):
  """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

  Note:
@@ -160,7 +154,7 @@ class MSAMObjective(MSAM):
  to an incorrect update rule.

  Args:
- modules (Chainable): modules that will optimizer the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
+ modules (Chainable): modules that will optimize the MSAM objective. Make sure ``tz.m.LR`` is one of them.
  momentum (float, optional): momentum (beta). Defaults to 0.9.
  rho (float, optional): perturbation strength. Defaults to 0.3.
  nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
@@ -169,20 +163,44 @@ class MSAMObjective(MSAM):
  Defaults to False.

  Examples:
- AdamW-MSAM
-
- .. code-block:: python
-
- opt = tz.Modular(
- bench.parameters(),
- tz.m.MSAMObjective(
- [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
- rho=1.
- )
- )
+ AdamW-MSAM
+
+ ```py
+ opt = tz.Modular(
+ bench.parameters(),
+ tz.m.MSAMObjective(
+ [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+ rho=1.
+ )
+ )
+ ```
  """
- _USES_LR = False
  def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
- super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+ defaults = dict(momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+ super().__init__(defaults)
+
  self.set_child('modules', modules)

+
+ @torch.no_grad
+ def apply_states(self, objective, states, settings):
+ velocity = unpack_states(states, objective.params, 'velocity', cls=TensorList)
+ fs = settings[0]
+
+ momentum, rho, weight_decay = unpack_dicts(settings, 'momentum', 'rho', 'weight_decay', cls=NumberList)
+
+ return msam_(
+ TensorList(objective.get_updates()),
+ params=TensorList(objective.params),
+ velocity_=velocity,
+ momentum=momentum,
+ lr=None,
+ rho=rho,
+ weight_decay=weight_decay,
+ nesterov=fs['nesterov'],
+ lerp=fs['lerp'],
+
+ # inner args
+ inner=self.children["modules"],
+ objective=objective,
+ )
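To summarize the renames in the hunks above: the drop-in, gradient-based transform formerly called `MSAM` is now `MSAMMomentum`, and the `MSAM` name now refers to the wrapper that was `MSAMObjective` (it takes `modules` and drives them through the MSAM objective). A sketch of both 0.4.0 entry points, assuming they are exported under `tz.m` like the other modules in the docstrings (the aliases are an assumption, the constructor arguments come from the diff):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# drop-in momentum replacement (0.3.15 `MSAM`, now `MSAMMomentum`)
opt_momentum = tz.Modular(
    model.parameters(),
    tz.m.MSAMMomentum(lr=1e-3, momentum=0.9, rho=0.3),
)

# objective wrapper (0.3.15 `MSAMObjective`, now `MSAM`);
# per its docstring, keep tz.m.LR among the inner modules
opt_objective = tz.Modular(
    model.parameters(),
    tz.m.MSAM([tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)], rho=0.3),
)
```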
torchzero/modules/adaptive/muon.py

@@ -1,14 +1,11 @@
  from operator import itemgetter
  import math
- import warnings
- from collections.abc import Iterable, Sequence
- from typing import Literal
+ from collections.abc import Iterable

  import torch

- from ...core import Modular, TensorwiseTransform, Target, Transform
- from ...utils import enable_compilation
-
+ from ...core import TensorTransform, Transform
+ from ...linalg.orthogonalize import orthogonalize as _orthogonalize, OrthogonalizeMethod

  def reverse_dims(t:torch.Tensor):
  return t.permute(*reversed(range(t.ndim)))
@@ -17,136 +14,69 @@ def _is_at_least_2d(p: torch.Tensor):
  if (p.ndim >= 2) and (p.size(0) > 1) and (p.size(1) > 1): return True
  return False

- # stolen from:
- # https://github.com/KellerJordan/Muon/blob/master/muon.py
- # actually at this stage its a frankenstein
- @enable_compilation
- def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
- """
- Applies to last 2 dims - so usually reverse_dims should be applied to G before and after.
-
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
- performance at all relative to UV^T, where USV^T = G is the SVD.
- """
- assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
- a, b, c = (3.4445, -4.7750, 2.0315)
- X = G.bfloat16()
- if G.size(-2) > G.size(-1):
- X = X.mT
-
- # Ensure spectral norm is at most 1
- X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
- # Perform the NS iterations
- for _ in range(steps):
- A = X @ X.mT
- B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
- X = a * X + B @ X
-
- if G.size(-2) > G.size(-1):
- X = X.mT
- return X
-
- # stolen from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
- # Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
- # Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
- @torch.no_grad
- def _svd_orthogonalize(G: torch.Tensor, warn_fail=True) -> torch.Tensor:
- """
- Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
- """
- X = G.view(G.shape[0], -1)
-
- t = False
- if X.size(0) > X.size(1):
- X = X.T
- t = True
-
- orth_X: torch.Tensor | None = None
- try:
- u, s, vt = torch.linalg.svd(X, full_matrices=False) # pylint:disable=not-callable
- orth_X = u @ vt
- except RuntimeError:
- # if warn: logging.warning('Failed to perform SVD, adding some noise.')
- try:
- u, s, v = torch.svd_lowrank(
- X,
- q=1, # assume rank is at least 1
- M=1e-4 * X.mean() * torch.randn_like(X))
- orth_X = u @ v.T
- except RuntimeError:
- if warn_fail: warnings.warn(('Failed to perform SVD with noise,'
- ' skipping gradient orthogonalisation'))
- if orth_X is not None:
- if t: orth_X = orth_X.T
- return orth_X.view_as(G)
-
- return G # fail
+ def _orthogonalize_format(
+ tensor: torch.Tensor,
+ method: OrthogonalizeMethod,
+ channel_first: bool,
+ ):
+ if channel_first:
+ return reverse_dims(_orthogonalize(reverse_dims(tensor), method=method))

+ return _orthogonalize(tensor, method=method)

  @torch.no_grad
- def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor, batch_first):
- """batch first means it applies to last 2 dims, otherwise to 1st two dims"""
+ def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor, channel_first: bool):
+ """``channel_first`` means it applies to first two dims, otherwise to last two dims"""
  # this is from https://github.com/leloykun/adaptive-muon
  # Adaptive scaling,`(G * X).sum() * X` == (G.T @ X).trace() * X
- if batch_first: X = torch.einsum('...ij,...ij,...ab->...ab', g.type_as(X), X, X)
- else: X = torch.einsum('ij...,ij...,ab...->ab...', g.type_as(X), X, X)
+ if channel_first: X = torch.einsum('ij...,ij...,ab...->ab...', g.type_as(X), X, X)
+ else: X = torch.einsum('...ij,...ij,...ab->...ab', g.type_as(X), X, X)
  return X


  # code from
  # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
- def adjust_lr_for_muon(lr, param_shape):
- A, B = param_shape[:2]
+ def adjust_lr_for_muon(lr, param_shape, channel_first:bool):
+ if channel_first: A, B = param_shape[:2]
+ else: A, B = param_shape[-2:]
+
  # We adjust the learning rate and weight decay based on the size of the parameter matrix
  # as describted in the paper
  adjusted_ratio = 0.2 * math.sqrt(max(A, B))
  adjusted_lr = lr * adjusted_ratio
  return adjusted_lr

- def _orthogonalize_tensor(
- tensor: torch.Tensor,
- steps: int = 5,
- method: Literal["newton-schulz", "svd"] = "newton-schulz",
- ):
- if method == 'newton-schulz': return reverse_dims(zeropower_via_newtonschulz5(reverse_dims(tensor), steps)).type_as(tensor)
- if method == 'svd': return _svd_orthogonalize(tensor, False)
- raise ValueError(method)
-

  def orthogonalize_grads_(
  params: Iterable[torch.Tensor],
- steps: int = 5,
  dual_norm_correction=False,
- method: Literal["newton-schulz", "svd"] = "newton-schulz",
+ method: OrthogonalizeMethod = "newtonschulz",
+ channel_first:bool=True,
  ):
- """Uses newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.
+ """Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.

  This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).

  Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.
  Args:
  params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
- steps (int, optional):
- The number of Newton-Schulz iterations to run. Defaults to 5.
  dual_norm_correction (bool, optional):
  enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
  method (str, optional):
  Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
+ channel_first (bool, optional):
+ if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
+ are considered batch dimensions.
  """
  for p in params:
  if (p.grad is not None) and _is_at_least_2d(p.grad):
- X = _orthogonalize_tensor(p.grad, steps, method)
- if dual_norm_correction: X = _dual_norm_correction(X, p.grad, batch_first=False)
+ X = _orthogonalize_format(p.grad, method=method, channel_first=channel_first)
+ if dual_norm_correction: X = _dual_norm_correction(X, p.grad, channel_first=False)
  p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]



- class Orthogonalize(TensorwiseTransform):
+ class Orthogonalize(TensorTransform):
  """Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

  To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
@@ -156,16 +86,15 @@ class Orthogonalize(TensorwiseTransform):
  To make Muon, use Split with Adam on 1d params

  Args:
- ns_steps (int, optional):
- The number of Newton-Schulz iterations to run. Defaults to 5.
  adjust_lr (bool, optional):
  Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
  dual_norm_correction (bool, optional):
  enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
  method (str, optional):
- Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
- target (str, optional):
- what to set on var.
+ Newton-Schulz is very fast, SVD is slow but can be more precise.
+ channel_first (bool, optional):
+ if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
+ are considered batch dimensions.

  ## Examples:

@@ -190,56 +119,62 @@
  Reference:
  Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
  """
- def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
- method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
- defaults = dict(orthogonalize=True, ns_steps=ns_steps, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower())
- super().__init__(uses_grad=False, defaults=defaults, target=target)
+ def __init__(self, adjust_lr=False, dual_norm_correction=False,
+ method: OrthogonalizeMethod = 'newtonschulz', channel_first:bool=True):
+ defaults = dict(orthogonalize=True, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower(), channel_first=channel_first)
+ super().__init__(defaults=defaults)

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
- orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
- 'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+ orthogonalize, dual_norm_correction, adjust_lr, method, channel_first = itemgetter(
+ 'orthogonalize', 'dual_norm_correction', 'adjust_lr', 'method', 'channel_first')(setting)

  if not orthogonalize: return tensor

  if _is_at_least_2d(tensor):

- X = _orthogonalize_tensor(tensor, ns_steps, method)
+ X = _orthogonalize_format(tensor, method, channel_first=channel_first)

  if dual_norm_correction:
- X = _dual_norm_correction(X, tensor, batch_first=False)
+ X = _dual_norm_correction(X, tensor, channel_first=channel_first)

  if adjust_lr:
- X.mul_(adjust_lr_for_muon(1, param.shape))
+ X.mul_(adjust_lr_for_muon(1, param.shape, channel_first=channel_first))

  return X.view_as(param)

  return tensor


- class DualNormCorrection(TensorwiseTransform):
+ class DualNormCorrection(TensorTransform):
  """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
  Orthogonalize already has this built in with the `dual_norm_correction` setting."""
- def __init__(self, target: Target='update'):
- super().__init__({}, uses_grad=True, target=target)
+ def __init__(self, channel_first: bool = True):
+ defaults = dict(channel_first=channel_first)
+ super().__init__(defaults)

- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ @torch.no_grad
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  assert grad is not None
  if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
- return _dual_norm_correction(tensor, grad, batch_first=False)
+ return _dual_norm_correction(tensor, grad, channel_first=setting["channel_first"])
  return tensor


  class MuonAdjustLR(Transform):
  """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
- Orthogonalize already has this built in with the `adjust_lr` setting, however you might want to move this to be later in the chain."""
- def __init__(self, alpha: float = 1, target: Target='update'):
- defaults = dict(alpha=alpha)
- super().__init__(defaults=defaults, uses_grad=False, target=target)
+ Orthogonalize already has this built in with the ``adjust_lr`` setting, however you might want to move this to be later in the chain."""
+ def __init__(self, channel_first: bool = True, alpha: float = 1):
+ defaults = dict(channel_first=channel_first, alpha=alpha)
+ super().__init__(defaults=defaults)

- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ @torch.no_grad
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  alphas = [s['alpha'] for s in settings]
- tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
+ channel_first = [s["channel_first=channel_first"] for s in settings]
+ tensors_alphas = [
+ (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t)
+ ]
  tensors = [i[0] for i in tensors_alphas]
  a = [i[1] for i in alphas]
  torch._foreach_mul_(tensors, a)
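The muon.py hunks above delete the in-file Newton-Schulz and SVD routines in favour of `torchzero.linalg.orthogonalize`, drop the `ns_steps` and `target` arguments, and add a `channel_first` flag that also feeds the Muon LR adjustment. A constructor-level sketch of the 0.4.0 `Orthogonalize`, assuming it is exported as `tz.m.Orthogonalize` and chained with `tz.m.LR` as in the package's other examples (the export and chaining are assumptions; argument names come from the hunk):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(16, 16)
opt = tz.Modular(
    model.parameters(),
    # 0.3.15: Orthogonalize(ns_steps=5, method="newton-schulz", target="update")
    # 0.4.0: the iteration lives in torchzero.linalg; method is spelled "newtonschulz"
    tz.m.Orthogonalize(adjust_lr=True, method="newtonschulz", channel_first=True),
    tz.m.LR(1e-2),
)
```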