torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/adaptive/mars.py (new file)
@@ -0,0 +1,79 @@
+ import torch
+
+ from ...core import Transform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+
+ def mars_correction_(
+     tensors_: TensorList,
+     prev_: TensorList,
+     beta: float | NumberList,
+     scaling: float | NumberList,
+     max_norm: float | NumberList | None,
+ ):
+     dg = (tensors_ - prev_).mul_(scaling * beta / (1 - beta))
+     prev_.copy_(tensors_)
+
+     c = tensors_.add_(dg)
+     if max_norm is not None:
+         c.clip_norm_(max=max_norm, tensorwise=False)
+
+     return c
+
+ class MARSCorrection(Transform):
+     """MARS variance reduction correction.
+
+     Place any other momentum-based optimizer after this module, and
+     make sure its ``beta`` parameter matches the momentum used in that optimizer.
+
+     Args:
+         beta (float, optional): use the same beta as in the momentum module. Defaults to 0.9.
+         scaling (float, optional): controls the scale of the gradient correction in variance reduction. Defaults to 0.025.
+         max_norm (float, optional): clips the norm of corrected gradients; None to disable. Defaults to 1.
+
+     ## Examples:
+
+     Mars-AdamW
+     ```python
+     optimizer = tz.Modular(
+         model.parameters(),
+         tz.m.MARSCorrection(beta=0.95),
+         tz.m.Adam(beta1=0.95, beta2=0.99),
+         tz.m.WeightDecay(1e-3),
+         tz.m.LR(0.1)
+     )
+     ```
+
+     Mars-Lion
+     ```python
+     optimizer = tz.Modular(
+         model.parameters(),
+         tz.m.MARSCorrection(beta=0.9),
+         tz.m.Lion(beta1=0.9),
+         tz.m.LR(0.1)
+     )
+     ```
+
+     """
+     def __init__(
+         self,
+         beta: float = 0.9,
+         scaling: float = 0.025,
+         max_norm: float | None = 1,
+     ):
+         defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+         beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
+         max_norm = settings[0]['max_norm']
+
+         return mars_correction_(
+             tensors_=TensorList(tensors),
+             prev_=prev,
+             beta=beta,
+             scaling=scaling,
+             max_norm=max_norm,
+         )
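The correction above boils down to one formula. The standalone sketch below (plain tensors with toy values, rather than the ``TensorList`` API used in the module) spells out the arithmetic that ``mars_correction_`` performs each step:

```python
import torch

beta, scaling, max_norm = 0.9, 0.025, 1.0   # defaults from the class above
g_prev = torch.randn(100)                   # gradient stored from the previous step
g = torch.randn(100)                        # current gradient

# corrected gradient: c = g + scaling * beta / (1 - beta) * (g - g_prev)
c = g + scaling * beta / (1 - beta) * (g - g_prev)

# clip the global norm of the corrected gradient
if c.norm() > max_norm:
    c = c * (max_norm / c.norm())

g_prev = g.clone()                          # becomes 'prev' for the next step
```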
torchzero/modules/adaptive/matrix_momentum.py (new file)
@@ -0,0 +1,146 @@
+ from typing import Literal
+ from collections.abc import Callable
+ import torch
+
+ from ...core import Module, apply_transform, Chainable
+ from ...utils import NumberList, TensorList, as_tensorlist
+ from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+ from ..functional import initial_step_size
+
+
+ class MatrixMomentum(Module):
+     """Second order momentum method.
+
+     Matrix momentum is useful for convex objectives; for some reason it also has really good generalization on elastic net logistic regression.
+
+     Notes:
+         - ``mu`` needs to be tuned very carefully. It is supposed to be smaller than 1/(largest eigenvalue), otherwise this will be very unstable. I have devised an adaptive version of this - ``tz.m.AdaptiveMatrixMomentum`` - which works well without having to tune ``mu``; however, the adaptive version doesn't work on stochastic objectives.
+
+         - In most cases ``MatrixMomentum`` should be the first module in the chain because it relies on autograd.
+
+         - This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument.
+
+     Args:
+         mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
+         hvp_method (str, optional):
+             Determines how Hessian-vector products are evaluated.
+
+             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+               This requires creating a graph for the gradient.
+             - ``"forward"``: Use a forward finite difference formula to
+               approximate the HVP. This requires one extra gradient evaluation.
+             - ``"central"``: Use a central finite difference formula for a
+               more accurate HVP approximation. This requires two extra
+               gradient evaluations.
+             Defaults to "autograd".
+         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+         hvp_tfm (Chainable | None, optional): optional module applied to Hessian-vector products. Defaults to None.
+
+     Reference:
+         Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
+     """
+
+     def __init__(
+         self,
+         lr: float,
+         mu=0.1,
+         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+         h: float = 1e-3,
+         adaptive: bool = False,
+         adapt_freq: int | None = None,
+         hvp_tfm: Chainable | None = None,
+     ):
+         defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
+         super().__init__(defaults)
+
+         if hvp_tfm is not None:
+             self.set_child('hvp_tfm', hvp_tfm)
+
+     def reset_for_online(self):
+         super().reset_for_online()
+         self.clear_state_keys('p_prev')
+
+     @torch.no_grad
+     def update(self, var):
+         assert var.closure is not None
+         p = TensorList(var.params)
+         p_prev = self.get_state(p, 'p_prev', init=var.params)
+
+         hvp_method = self.defaults['hvp_method']
+         h = self.defaults['h']
+         step = self.global_state.get("step", 0)
+         self.global_state["step"] = step + 1
+
+         if step > 0:
+             s = p - p_prev
+
+             Hs, _ = self.Hvp(s, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+             Hs = [t.detach() for t in Hs]
+
+             if 'hvp_tfm' in self.children:
+                 Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))
+
+             self.store(p, ("Hs", "s"), (Hs, s))
+
+             # -------------------------------- adaptive mu ------------------------------- #
+             if self.defaults["adaptive"]:
+                 g = TensorList(var.get_grad())
+
+                 if self.defaults["adapt_freq"] is None:
+                     # ---------------------------- deterministic case ---------------------------- #
+                     g_prev = self.get_state(var.params, "g_prev", cls=TensorList)
+                     y = g - g_prev
+                     g_prev.copy_(g)
+                     denom = y.global_vector_norm()
+                     denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
+                     self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+                 else:
+                     # -------------------------------- stochastic -------------------------------- #
+                     adapt_freq = self.defaults["adapt_freq"]
+
+                     # we start on the 1st step, and want to adapt when we start, so use (step - 1)
+                     if (step - 1) % adapt_freq == 0:
+                         assert var.closure is not None
+                         params = TensorList(var.params)
+                         p_cur = params.clone()
+
+                         # move to previous params and evaluate p_prev with current mini-batch
+                         params.copy_(self.get_state(var.params, 'p_prev'))
+                         with torch.enable_grad():
+                             var.closure()
+                         g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                         y = g - g_prev
+
+                         # move back to current params
+                         params.copy_(p_cur)
+
+                         denom = y.global_vector_norm()
+                         denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
+                         self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+         torch._foreach_copy_(p_prev, var.params)
+
+     @torch.no_grad
+     def apply(self, var):
+         update = TensorList(var.get_update())
+         lr, mu = self.get_settings(var.params, "lr", 'mu', cls=NumberList)
+
+         if "mu_mul" in self.global_state:
+             mu = mu * self.global_state["mu_mul"]
+
+         # --------------------------------- 1st step --------------------------------- #
+         # p_prev is not available so make a small step
+         step = self.global_state["step"]
+         if step == 1:
+             if self.defaults["adaptive"]: self.get_state(var.params, "g_prev", init=var.get_grad())
+             update.mul_(lr)  # separate so that initial_step_size can clip correctly
+             update.mul_(initial_step_size(update, 1e-7))
+             return var
+
+         # -------------------------- matrix momentum update -------------------------- #
+         s, Hs = self.get_state(var.params, 's', 'Hs', cls=TensorList)
+
+         update.mul_(lr).sub_(s).add_(Hs * mu)
+         var.update = update
+         return var
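Since the docstring above has no usage example, here is a hypothetical sketch of how this module could be wired up, assuming the ``tz.Modular`` pattern and the closure convention shown elsewhere in this diff (the closure re-evaluates the loss and only calls ``backward()`` when asked to). The ``lr``/``mu`` values are placeholders; ``mu`` would normally need tuning against the largest Hessian eigenvalue:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.MatrixMomentum(lr=1e-2, mu=0.01, hvp_method="autograd"),
)

# MatrixMomentum needs a closure that accepts a ``backward`` flag so it can
# re-evaluate the loss and gradients when computing Hessian-vector products.
def closure(backward=True):
    loss = (model(X) - y).pow(2).mean()
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    loss = opt.step(closure)
```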
torchzero/modules/adaptive/msam.py (new file)
@@ -0,0 +1,188 @@
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, Transform, apply_transform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
+ from ..functional import ema_
+ from ..momentum.momentum import nag_
+
+
+ def msam_(
+     tensors: TensorList,
+     params: TensorList,
+     velocity_: TensorList,
+     momentum: float | NumberList,
+     lr: NumberList | None,
+     rho: float | NumberList,
+     weight_decay: float | NumberList,
+     nesterov: bool = False,
+     lerp: bool = False,
+
+     # inner args
+     inner: Module | None = None,
+     grads: list[torch.Tensor] | None = None,
+ ):
+     # weights w and wh, momentum μ, perturbation strength ρ
+     # w = wh + rho * v / ||v||
+     # v1 = μv + g
+     # w1 = w - lr*v1
+     # wh1 = w1 - rho * v1 / ||v1||
+
+     # w1 = wh + rho * v / ||v|| - lr*v1
+     # vn = rho * v / ||v||
+     # v1n = rho * v1 / ||v1||
+     # wh1 = wh + vn - lr*v1 - v1n
+
+     # the update is
+     # vn - lr*v1 - v1n
+
+     # we track ascent direction so it becomes lr*v1 + v1n - vn
+
+     # can't really decouple it from lr
+     # but at least it is now expressed as a function of g
+
+     denom = velocity_.global_vector_norm() / rho
+     denom = denom.clip(min=torch.finfo(tensors[0].dtype).tiny * 2)
+     vn = velocity_ / denom
+
+     mom_ = nag_ if nesterov else ema_
+     velocity_ = mom_(tensors, velocity_, momentum, dampening=0, lerp=lerp)
+
+     denom = velocity_.global_vector_norm() / rho
+     denom = denom.clip(min=torch.finfo(tensors[0].dtype).tiny * 2)
+     v1n = velocity_ / denom
+
+     if inner is not None:
+         assert params is not None
+         inner_update = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+     else:
+         assert lr is not None
+         inner_update = velocity_ * lr
+
+     update = inner_update.add_(v1n).sub_(vn)
+
+     if generic_ne(weight_decay, 0):
+         wd = (params + vn).mul_(weight_decay)
+         update.add_(wd)
+
+     return update
+
+ class MSAM(Transform):
+     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+     This implementation expresses the update rule as a function of the gradient. This way it can be used as a drop-in
+     replacement for momentum strategies in other optimizers.
+
+     To combine MSAM with other optimizers in the way done in the official implementation,
+     e.g. to make Adam_MSAM, use the ``tz.m.MSAMObjective`` module.
+
+     Note:
+         MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
+         To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.
+
+     Args:
+         lr (float): learning rate. Adding this module adds support for learning rate schedulers.
+         momentum (float, optional): momentum (beta). Defaults to 0.9.
+         rho (float, optional): perturbation strength. Defaults to 0.3.
+         weight_decay (float, optional):
+             weight decay. It is applied to perturbed parameters, so it is different
+             from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
+         nesterov (bool, optional): whether to use the nesterov momentum formula. Defaults to False.
+         lerp (bool, optional):
+             whether to use linear interpolation; if True, this becomes similar to an exponential moving average. Defaults to False.
+
+     Examples:
+         MSAM
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.MSAM(1e-3)
+             )
+
+         Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+         To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+                 tz.m.Debias(0.9, 0.999),
+             )
+     """
+     _USES_LR = True
+     def __init__(self, lr: float, momentum: float = 0.9, rho: float = 0.3, weight_decay: float = 0, nesterov=False, lerp=False):
+         defaults = dict(momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
+         if self._USES_LR: defaults['lr'] = lr
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+         s = self.settings[params[0]]
+         lerp = s['lerp']
+         nesterov = s['nesterov']
+
+         if self._USES_LR:
+             lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr', 'momentum', 'rho', 'weight_decay', cls=NumberList)
+
+         else:
+             lr = None
+             momentum, rho, weight_decay = unpack_dicts(settings, 'momentum', 'rho', 'weight_decay', cls=NumberList)
+
+         return msam_(
+             TensorList(tensors),
+             params=TensorList(params),
+             velocity_=velocity,
+             momentum=momentum,
+             lr=lr,
+             rho=rho,
+             weight_decay=weight_decay,
+             nesterov=nesterov,
+             lerp=lerp,
+
+             # inner args
+             inner=self.children.get("modules", None),
+             grads=grads,
+         )
+
+
+ class MSAMObjective(MSAM):
+     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+     Note:
+         Please make sure to place ``tz.m.LR`` inside the ``modules`` argument. For example,
+         ``tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])``. Putting LR after MSAM will lead
+         to an incorrect update rule.
+
+     Args:
+         modules (Chainable): modules that will optimize the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
+         momentum (float, optional): momentum (beta). Defaults to 0.9.
+         rho (float, optional): perturbation strength. Defaults to 0.3.
+         nesterov (bool, optional): whether to use the nesterov momentum formula. Defaults to False.
+         lerp (bool, optional):
+             whether to use linear interpolation; if True, MSAM momentum becomes similar to an exponential moving average.
+             Defaults to False.
+
+     Examples:
+         AdamW-MSAM
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 bench.parameters(),
+                 tz.m.MSAMObjective(
+                     [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+                     rho=1.
+                 )
+             )
+     """
+     _USES_LR = False
+     def __init__(self, modules: Chainable, momentum: float = 0.9, rho: float = 0.3, weight_decay: float = 0, nesterov=False, lerp=False):
+         super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+         self.set_child('modules', modules)
+
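For reference, the update rule that ``msam_`` derives in its comment block can be written out with plain tensors. This is only a sketch of the arithmetic for the basic case (no Nesterov, no lerp, no weight decay, no inner module), not the torchzero API:

```python
import torch

g = torch.randn(100)           # gradient evaluated at the perturbed weights w
v = torch.randn(100)           # velocity buffer from the previous step
lr, mu, rho = 1e-3, 0.9, 0.3   # toy hyperparameters

vn = rho * v / v.norm()        # old perturbation  rho * v / ||v||
v1 = mu * v + g                # updated velocity
v1n = rho * v1 / v1.norm()     # new perturbation  rho * v1 / ||v1||

# ascent direction tracked by the module: lr*v1 + v1n - vn,
# which is subtracted from the un-perturbed weights wh
update = lr * v1 + v1n - vn
```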
torchzero/modules/{optimizers → adaptive}/muon.py
@@ -19,6 +19,7 @@ def _is_at_least_2d(p: torch.Tensor):
 
  # stolen from:
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
+ # actually at this stage it's a frankenstein
  @enable_compilation
  def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
      """
@@ -152,7 +153,7 @@ class Orthogonalize(TensorwiseTransform):
      The Muon page says that embeddings and classifier heads should not be orthogonalized.
      Usually only matrix parameters that are directly used in matmuls should be orthogonalized.
 
-     To make Muon, use Split with Adam on 1d params: TODO code example.
+     To make Muon, use Split with Adam on 1d params (see the example below).
 
      Args:
          ns_steps (int, optional):
@@ -165,6 +166,29 @@ class Orthogonalize(TensorwiseTransform):
              Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
          target (str, optional):
              what to set on var.
+
+     ## Examples:
+
+     standard Muon with Adam fallback
+     ```py
+     opt = tz.Modular(
+         model.head.parameters(),
+         tz.m.Split(
+             # apply muon only to 2D+ parameters
+             filter = lambda t: t.ndim >= 2,
+             true = [
+                 tz.m.HeavyBall(),
+                 tz.m.Orthogonalize(),
+                 tz.m.LR(1e-2),
+             ],
+             false = tz.m.Adam()
+         ),
+         tz.m.LR(1e-2)
+     )
+     ```
+
+     Reference:
+         Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
      """
      def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
                   method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
@@ -172,9 +196,9 @@ class Orthogonalize(TensorwiseTransform):
          super().__init__(uses_grad=False, defaults=defaults, target=target)
 
      @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
          orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
-             'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(settings)
+             'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)
 
          if not orthogonalize: return tensor
 
@@ -199,7 +223,7 @@ class DualNormCorrection(TensorwiseTransform):
      def __init__(self, target: Target='update'):
          super().__init__({}, uses_grad=True, target=target)
 
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
          assert grad is not None
          if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
              return _dual_norm_correction(tensor, grad, batch_first=False)
@@ -213,7 +237,7 @@ class MuonAdjustLR(Transform):
          defaults = dict(alpha=alpha)
          super().__init__(defaults=defaults, uses_grad=False, target=target)
 
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          alphas = [s['alpha'] for s in settings]
          tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
          tensors = [i[0] for i in tensors_alphas]
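The hunks above keep building on ``zeropower_via_newtonschulz5``. For readers who don't have the Muon repository at hand, the sketch below shows the core quintic Newton-Schulz iteration that name refers to (coefficients as published in the Muon repository; the bfloat16 casting, eps handling and compilation of the real implementation are omitted):

```python
import torch

def newton_schulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately replace G's singular values with 1, i.e. return roughly U @ V.T of its SVD."""
    a, b, c = 3.4445, -4.7750, 2.0315   # quintic iteration coefficients
    X = G / (G.norm() + 1e-7)           # scale so singular values are at most ~1
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T                         # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    if transposed:
        X = X.T
    return X
```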
torchzero/modules/adaptive/natural_gradient.py (new file)
@@ -0,0 +1,175 @@
+ import torch
+ from ...core import Module, Chainable, apply_transform
+
+ from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+ from ...utils import vec_to_tensors, TensorList
+ from ...utils.linalg import linear_operator
+ from .lmadagrad import lm_adagrad_apply, lm_adagrad_update
+
+ class NaturalGradient(Module):
+     """Natural gradient approximated via the empirical Fisher information matrix.
+
+     To use this, either pass a vector of per-sample losses to the step method, or make sure
+     the closure returns it. Gradients will be calculated via batched autograd within this module,
+     so you don't need to implement the backward pass. When using a closure, please add the ``backward`` argument;
+     it will always be False, but it is required. See below for an example.
+
+     Note:
+         The empirical Fisher information matrix may give a really bad approximation in some cases.
+         If that is the case, set ``sqrt`` to True to perform whitening instead, which is way more robust.
+
+     Args:
+         reg (float, optional): regularization parameter. Defaults to 1e-8.
+         sqrt (bool, optional):
+             if True, uses the square root of the empirical Fisher information matrix. Both the EFIM and its square
+             root can be calculated and stored efficiently without ndim^2 memory. The square root
+             whitens the gradient and often performs much better, especially when you try to use NGD
+             with a vector that isn't strictly per-sample gradients, but rather, for example, different losses.
+         gn_grad (bool, optional):
+             if True, uses the Gauss-Newton G^T @ f as the gradient, which is effectively a sum weighted by value
+             and is equivalent to squaring the values. This way you can solve least-squares
+             objectives with an NGD-like algorithm. If False, uses the sum of per-sample gradients.
+             This has an effect when ``sqrt=True``, and affects the ``grad`` attribute.
+             Defaults to False.
+         batched (bool, optional): whether to use vmapping. Defaults to True.
+
+     Examples:
+
+     training a neural network:
+     ```python
+     X = torch.randn(64, 20)
+     y = torch.randn(64, 10)
+
+     model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.NaturalGradient(),
+         tz.m.LR(3e-2)
+     )
+
+     for i in range(100):
+         y_hat = model(X) # (64, 10)
+         losses = (y_hat - y).pow(2).mean(0) # (10, )
+         opt.step(loss=losses)
+         if i % 10 == 0:
+             print(f'{losses.mean() = }')
+     ```
+
+     training a neural network - closure version
+     ```python
+     X = torch.randn(64, 20)
+     y = torch.randn(64, 10)
+
+     model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.NaturalGradient(),
+         tz.m.LR(3e-2)
+     )
+
+     def closure(backward=True):
+         y_hat = model(X) # (64, 10)
+         return (y_hat - y).pow(2).mean(0) # (10, )
+
+     for i in range(100):
+         losses = opt.step(closure)
+         if i % 10 == 0:
+             print(f'{losses.mean() = }')
+     ```
+
+     minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:
+     ```python
+     def rosenbrock(X):
+         x1, x2 = X
+         return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])
+
+     X = torch.tensor([-1.1, 2.5], requires_grad=True)
+     opt = tz.Modular([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))
+
+     for iter in range(200):
+         losses = rosenbrock(X)
+         opt.step(loss=losses)
+         if iter % 20 == 0:
+             print(f'{losses.mean() = }')
+     ```
+     """
+     def __init__(self, reg: float = 1e-8, sqrt: bool = False, gn_grad: bool = False, batched: bool = True):
+         super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))
+
+     @torch.no_grad
+     def update(self, var):
+         params = var.params
+         batched = self.defaults['batched']
+         gn_grad = self.defaults['gn_grad']
+
+         closure = var.closure
+         assert closure is not None
+
+         with torch.enable_grad():
+             f = var.get_loss(backward=False) # n_out
+             assert isinstance(f, torch.Tensor)
+             G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+
+         var.loss = f.sum()
+         G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)
+
+         if gn_grad:
+             g = self.global_state["g"] = G.H @ f.detach()
+
+         else:
+             g = self.global_state["g"] = G.sum(0)
+
+         var.grad = vec_to_tensors(g, params)
+
+         # set closure to calculate scalar value for line searches etc
+         if var.closure is not None:
+             def ngd_closure(backward=True):
+                 if backward:
+                     var.zero_grad()
+                     with torch.enable_grad():
+                         loss = closure(False)
+                         if gn_grad: loss = loss.pow(2)
+                         loss = loss.sum()
+                         loss.backward()
+                     return loss
+
+                 loss = closure(False)
+                 if gn_grad: loss = loss.pow(2)
+                 return loss.sum()
+
+             var.closure = ngd_closure
+
+     @torch.no_grad
+     def apply(self, var):
+         params = var.params
+         reg = self.defaults['reg']
+         sqrt = self.defaults['sqrt']
+
+         G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)
+
+         if sqrt:
+             # this computes U, S <- SVD(M), then calculates the update as U S^-1 Uᵀg,
+             # but it does so through an eigendecomposition
+             U, L = lm_adagrad_update(G.H, reg, 0)
+             if U is None or L is None: return var
+
+             v = lm_adagrad_apply(self.global_state["g"], U, L)
+             var.update = vec_to_tensors(v, params)
+             return var
+
+         GGT = G @ G.H # (n_samples, n_samples)
+
+         if reg != 0:
+             GGT.add_(torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype).mul_(reg))
+
+         z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
+         v = G.H @ z
+
+         var.update = vec_to_tensors(v, params)
+         return var
+
+
+     def get_H(self, var):
+         if "G" not in self.global_state: return linear_operator.ScaledIdentity()
+         G = self.global_state['G']
+         return linear_operator.AtA(G)
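The dense (non-``sqrt``) branch above only ever factorizes the ``(n_samples, n_samples)`` Gram matrix ``GGT`` and then maps the solution back through ``G.H``. The standalone check below (plain torch, independent of torchzero) verifies the identity this relies on: when ``G`` has full row rank, ``(GᵀG)⁺ g`` equals ``Gᵀ (GGᵀ)⁻¹ ones`` for the summed gradient ``g = Gᵀ ones``, so the small system gives the same step as the ``ndim × ndim`` one:

```python
import torch

torch.manual_seed(0)
n_samples, ndim = 8, 50
G = torch.randn(n_samples, ndim)            # stand-in for per-sample gradients
ones = torch.ones(n_samples)
g = G.T @ ones                              # summed gradient (the non-gn_grad path)

primal = torch.linalg.pinv(G.T @ G) @ g     # ndim x ndim pseudo-inverse (expensive)
z = torch.linalg.solve(G @ G.T, ones)       # n_samples x n_samples solve (cheap)
dual = G.T @ z

print(torch.allclose(primal, dual, atol=1e-3))  # True up to numerical error
```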
torchzero/modules/{optimizers → adaptive}/orthograd.py
@@ -36,7 +36,7 @@ class OrthoGrad(Transform):
          defaults = dict(eps=eps, renormalize=renormalize)
          super().__init__(defaults, uses_grad=False, target=target)
 
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          eps = settings[0]['eps']
          renormalize = settings[0]['renormalize']