torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
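Most of the churn in this list is a reorganization of the module tree rather than new code: `torchzero/modules/optimizers/` becomes `torchzero/modules/adaptive/`, `modules/lr/` is replaced by `modules/step_size/`, and several grab-bag files under `modules/ops/` move to `modules/misc/`. As a rough sketch of what that means for downstream code (an editor's illustration, not taken from the package's own changelog), deep imports need the new paths, while the top-level `tz.m` aliases used throughout the docstrings below are presumably unaffected:

.. code-block:: python

    import torch
    import torchzero as tz

    # 0.3.10 layout (per the rename listed above):
    # from torchzero.modules.optimizers.shampoo import Shampoo

    # 0.3.13 layout, matching the rename above and the import change in the SOAP diff further down:
    from torchzero.modules.adaptive.shampoo import Shampoo

    model = torch.nn.Linear(4, 2)
    # code that only goes through the top-level namespace should not need changes,
    # assuming tz.m keeps re-exporting the moved modules
    opt = tz.Modular(model.parameters(), tz.m.Shampoo(), tz.m.LR(1e-3))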
@@ -40,7 +40,9 @@ def rmsprop_(
     return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
 
 class RMSprop(Transform):
-    """Divides graient by EMA of gradient squares. Matches pytorch RMSprop if "init" is set to "zeros".
+    """Divides graient by EMA of gradient squares.
+
+    This implementation is identical to :code:`torch.optim.RMSprop`.
 
     Args:
         smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
@@ -50,7 +52,8 @@ class RMSprop(Transform):
         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
         pow (float, optional): power used in second momentum power and root. Defaults to 2.
         init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
-        inner (Chainable | None, optional): Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
     """
     def __init__(
         self,
@@ -60,7 +63,7 @@ class RMSprop(Transform):
         debiased: bool = False,
         amsgrad: bool = False,
         pow: float = 2,
-        init: Literal["zeros", "update"] = "update",
+        init: Literal["zeros", "update"] = "zeros",
         inner: Chainable | None = None,
     ):
         defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
@@ -69,7 +72,7 @@ class RMSprop(Transform):
         if inner is not None:
             self.set_child('inner', inner)
 
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
         smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
         centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])
@@ -135,7 +135,8 @@ class Rprop(Transform):
     Next step, magnitude for that weight won't change.
 
     Compared to pytorch this also implements backtracking update when sign changes.
-    To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.
+
+    This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.
 
     Args:
         nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
@@ -164,7 +165,7 @@ class Rprop(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -223,7 +224,7 @@ class ScaleLRBySignChange(Transform):
         super().__init__(defaults, uses_grad=use_grad, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -257,8 +258,6 @@ class BacktrackOnSignChange(Transform):
     This is part of RProp update rule.
 
     Args:
-        normalize (bool, optional): renormalize update after masking. Defaults to False.
-        eps (_type_, optional): epsilon for normalization. Defaults to 1e-6.
         use_grad (bool, optional):
            if True, tracks sign change of the gradient,
            otherwise track sign change of the update. Defaults to True.
@@ -272,7 +271,7 @@ class BacktrackOnSignChange(Transform):
         super().__init__(defaults, uses_grad=use_grad)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -294,12 +293,29 @@ class BacktrackOnSignChange(Transform):
         return tensors
 
 class SignConsistencyMask(Transform):
-    """0 if sign changed 1 otherwise"""
+    """
+    Outputs a mask of sign consistency of current and previous inputs.
+
+    The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.
+
+    Examples:
+
+        GD that skips update for weights where gradient sign changed compared to previous gradient.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyMask()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(self,target: Target = 'update'):
         super().__init__({}, uses_grad=False, target = target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', cls=TensorList)
         mask = prev.mul_(tensors).gt_(0)
         prev.copy_(tensors)
@@ -307,7 +323,23 @@ class SignConsistencyMask(Transform):
 
 
 class SignConsistencyLRs(Transform):
-    """LR for each weight is increased when two consequtive update signs are the same, decreased otherwise. This returns the LRs themselves."""
+    """Outputs per-weight learning rates based on consecutive sign consistency.
+
+    The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nplus`. The learning rates are bounded to be in :code:`(lb, ub)` range.
+
+    Examples:
+
+        GD scaled by consecutive gradient sign consistency
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyLRs()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(
         self,
         nplus: float = 1.2,
@@ -321,7 +353,7 @@ class SignConsistencyLRs(Transform):
         super().__init__(defaults, uses_grad=False, target = target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
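A pattern that repeats across these hunks (and the Shampoo and SOAP hunks further down) is the rename of the per-tensor hook on `Transform` subclasses from `apply` to `apply_tensors`; the argument list itself is unchanged. For illustration, a minimal sketch of a custom transform written against the 0.3.13 hook name, assuming the subclassing conventions visible in the diff (`super().__init__(defaults, uses_grad=...)`, returning the transformed update) carry over; the class below is hypothetical:

.. code-block:: python

    import torch
    from torchzero.core import Transform

    class SignUpdate(Transform):
        """toy example: replaces the update with its elementwise sign"""
        def __init__(self):
            super().__init__({}, uses_grad=False)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            # 0.3.10 named this hook `apply`; the signature is the same
            return [t.sign_() for t in tensors]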
@@ -0,0 +1,163 @@
+from contextlib import nullcontext
+import torch
+from ...utils import TensorList, NumberList
+from ...core import Module
+
+
+class SAM(Module):
+    """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+        asam (bool, optional):
+            enables ASAM variant which makes perturbation relative to weight magnitudes.
+            ASAM requires a much larger :code:`rho`, like 0.5 or 1.
+            The :code:`tz.m.ASAM` class is idential to setting this argument to True, but
+            it has larger :code:`rho` by default.
+
+    Examples:
+        SAM-SGD:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.LR(1e-2)
+            )
+
+        SAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
+    """
+    def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
+        defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+
+        params = var.params
+        closure = var.closure
+        zero_grad = var.zero_grad
+        if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
+        p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
+        s = self.defaults
+        eps = s['eps']
+        asam = s['asam']
+
+        # 1/p + 1/q = 1
+        # okay, authors of SAM paper, I will manually solve your equation
+        # so q = -p/(1-p)
+        q = -p / (1-p)
+        # as a validation for 2 it is -2 / -1 = 2
+
+        @torch.no_grad
+        def sam_closure(backward=True):
+            orig_grads = None
+            if not backward:
+                # if backward is False, make sure this doesn't modify gradients
+                # to avoid issues
+                orig_grads = [p.grad for p in params]
+
+            # gradient at initial parameters
+            zero_grad()
+            with torch.enable_grad():
+                closure()
+
+            grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
+            grad_abs = grad.abs()
+
+            # compute e
+            term1 = grad.sign().mul_(rho)
+            term2 = grad_abs.pow(q-1)
+
+            if asam:
+                grad_abs.mul_(torch._foreach_abs(params))
+
+            denom = grad_abs.pow_(q).sum().pow(1/p)
+
+            e = term1.mul_(term2).div_(denom.clip(min=eps))
+
+            if asam:
+                e.mul_(torch._foreach_pow(params, 2))
+
+            # calculate loss and gradient approximation of inner problem
+            torch._foreach_add_(params, e)
+            if backward:
+                zero_grad()
+                with torch.enable_grad():
+                    # this sets .grad attributes
+                    sam_loss = closure()
+
+            else:
+                sam_loss = closure(False)
+
+            # and restore initial parameters
+            torch._foreach_sub_(params, e)
+
+            if orig_grads is not None:
+                for param,orig_grad in zip(params, orig_grads):
+                    param.grad = orig_grad
+
+            return sam_loss
+
+        var.closure = sam_closure
+        return var
+
+# different class because defaults for SAM are bad for ASAM
+class ASAM(SAM):
+    """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+
+    Examples:
+        ASAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ASAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
+    """
+    def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
+        super().__init__(rho=rho, p=p, eps=eps, asam=True)
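The `sam_closure` above implements the closed-form inner maximization from the SAM paper: with the dual exponent q satisfying 1/p + 1/q = 1 (so q = -p/(1-p), which is 2 when p = 2), the perturbation is e = rho * sign(g) * |g|^(q-1) / (||g||_q^q)^(1/p); the gradient is then re-evaluated at params + e and the parameters are restored. Because the module only swaps `var.closure`, using it from the top-level API just needs an ordinary torchzero closure. A rough usage sketch, assuming the usual torchzero closure convention (a callable that accepts a `backward` flag, zeroes gradients and calls `backward()` itself), which matches the `closure(False)` call seen in the code:

.. code-block:: python

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)

    opt = tz.Modular(model.parameters(), tz.m.SAM(rho=0.05), tz.m.LR(1e-2))

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(100):
        opt.step(closure)  # SAM evaluates this closure twice per step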
@@ -17,6 +17,7 @@ def update_shampoo_preconditioner_(
     update_freq: int,
     exp_override: int | None,
     beta: float | None,
+    reg: float
 ):
     for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
         if accumulator is None: continue
@@ -28,6 +29,8 @@ def update_shampoo_preconditioner_(
 
         if step % update_freq == 0:
             matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
+            if reg != 0:
+                accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
             set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))
 
 
@@ -59,7 +62,7 @@ def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
     if tensor.shape[sort_idxs[0]] > max_dim:
         return tensor, None, None
 
-    tensor = tensor.permute(*sort_idxs)
+    tensor = tensor.permute(*sort_idxs.tolist())
     flatten_end_idx = 0
     flat_sizes = []
     flat_numel = 1
@@ -80,19 +83,27 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     if flat_sizes is None: return tensor
     assert sort_idxs is not None
     tensor = tensor.unflatten(0, flat_sizes)
-    return tensor.permute(*np.argsort(sort_idxs))
+    return tensor.permute(*np.argsort(sort_idxs).tolist())
 
 
 class Shampoo(Transform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).
 
+    .. note::
+        Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.
+
+    .. note::
+        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+
+    .. note::
+        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+
     Args:
         decay (float | None, optional): slowly decays preconditioners. Defaults to None.
         beta (float | None, optional):
            if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
-        matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
        update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to None.
+        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
        merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
        precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
@@ -101,32 +112,58 @@ class Shampoo(Transform):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.
+
+    Examples:
+        Shampoo grafted to Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GraftModules(
+                    direction = tz.m.Shampoo(),
+                    magnitude = tz.m.Adam(),
+                ),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with Shampoo preconditioner
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+                tz.m.Debias(0.9, 0.999),
+                tz.m.LR(1e-3)
+            )
    """
    def __init__(
        self,
        decay: float | None = None,
        beta: float | None = None,
+        reg: float = 1e-12,
        update_freq: int = 10,
-        exp_override: int | None = None,
+        exp_override: int | None = 2,
        merge_small: bool = True,
        max_dim: int = 2_000,
        precondition_1d: bool = True,
        adagrad_eps: float = 1e-8,
        inner: Chainable | None = None,
    ):
-        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
+        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps, reg=reg)
        super().__init__(defaults, uses_grad=False)
 
        if inner is not None:
            self.set_child('inner', inner)
 
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        merged_tensors = [] # target with merged dims
 
        # update preconditioners
        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
-            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
-                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(setting)
+            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d, reg = itemgetter(
+                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d', "reg")(setting)
 
            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -161,6 +198,7 @@ class Shampoo(Transform):
                update_freq=update_freq,
                exp_override=exp_override,
                beta=beta,
+                reg=reg,
            )
 
        # inner step
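The new `reg` argument Tikhonov-regularizes each Kronecker-factor accumulator: before the inverse matrix root is taken, `reg * I` is added, which keeps the eigendecomposition well conditioned when an accumulator is nearly singular. The effect can be reproduced with plain PyTorch; the helper below is an editor's illustration only (the package uses its own `matrix_power_eigh`, whose internals are not shown in this diff):

.. code-block:: python

    import torch

    def regularized_matrix_power(accumulator: torch.Tensor, exponent: float, reg: float = 1e-12) -> torch.Tensor:
        # hypothetical stand-in for the package's matrix_power_eigh call:
        # add reg * I (the new `reg` term above), then take the matrix power via eigh
        n = accumulator.size(0)
        A = accumulator + reg * torch.eye(n, dtype=accumulator.dtype, device=accumulator.device)
        eigvals, eigvecs = torch.linalg.eigh(A)
        eigvals = eigvals.clamp(min=reg)           # guard against tiny negative eigenvalues
        return (eigvecs * eigvals.pow(exponent)) @ eigvecs.T

    G = torch.randn(8, 8)
    accumulator = G @ G.T                          # Shampoo accumulates G @ G.T per tensor dimension
    # exp_override now defaults to 2, so matrix_exp = -1/exp_override = -0.5 in the code above
    preconditioner = regularized_matrix_power(accumulator, exponent=-0.5)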
@@ -1,9 +1,10 @@
 from operator import itemgetter
+import warnings
 
 import torch
 
 from ...core import Chainable, Transform, apply_transform
-from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
 
 @torch.no_grad
 def update_soap_covariances_(
@@ -24,11 +25,9 @@ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
     Projects the gradient to the eigenbases of the preconditioner.
     """
     for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
            permute_order = list(range(1, len(tensors.shape))) + [0]
            tensors = tensors.permute(permute_order)
 
@@ -40,8 +39,7 @@ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
     Projects the gradient back to the original space.
     """
     for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
        else:
            permute_order = list(range(1, len(tensors.shape))) + [0]
@@ -55,37 +53,23 @@ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
     """
     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
     """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
 
     final = []
-    for m in matrix:
-        if len(m) == 0:
+    for m in mat:
+
+        if m is None or len(m) == 0:
            final.append([])
            continue
+
        try:
            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
+        except torch.linalg.LinAlgError:
            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
 
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
+        Q = torch.flip(Q, [1])
        final.append(Q)
+
    return final
 
 # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
@@ -95,42 +79,24 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
     Computes the eigenbases of the preconditioner using one round of power iteration
     followed by torch.linalg.qr decomposition.
     """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
+    final = []
 
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
+    for ind, (m,o) in enumerate(zip(GG, Q_list)):
 
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
+        # skip 1d or large dims
+        if m is None or len(m) == 0:
            final.append([])
            continue
+        assert o is not None
+
        est_eig = torch.diag(o.T @ m @ o)
        sort_idx = torch.argsort(est_eig, descending=True)
        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
 
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
+        power_iter = m @ o[:, sort_idx]
+        Q, _ = torch.linalg.qr(power_iter.to(torch.float32)) # pylint:disable=not-callable
+        Q = Q.to(power_iter.dtype)
+
        final.append(Q)
 
    return final, exp_avg_sq
@@ -156,6 +122,24 @@ class SOAP(Transform):
            learning rate. Defaults to 1.
        bias_correction (bool, optional):
            enables adam bias correction. Defaults to True.
+
+    Examples:
+        SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))
+
+        Stabilized SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SOAP(),
+                tz.m.NormalizeByEMA(max_ema_growth=1.2),
+                tz.m.LR(1e-2)
+            )
    """
    def __init__(
        self,
@@ -187,7 +171,7 @@ class SOAP(Transform):
        super().__init__(defaults, uses_grad=False)
 
    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        updates = []
        # update preconditioners
        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
@@ -200,7 +184,7 @@ class SOAP(Transform):
            # initialize state on 1st step
            if 'GG' not in state:
                state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.zeros_like(t)
+                state["exp_avg_sq_projected"] = torch.zeros_like(t)
 
                if not precondition_1d and t.ndim <= 1:
                    state['GG'] = []
@@ -214,7 +198,10 @@ class SOAP(Transform):
 
                if state['GG'] is not None:
                    update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
-                    state['Q'] = get_orthogonal_matrix(state['GG'])
+                    try: state['Q'] = get_orthogonal_matrix(state['GG'])
+                    except torch.linalg.LinAlgError as e:
+                        warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
+                        state["GG"] = None
 
                state['step'] = 0
                updates.append(tensors[i].clip(-0.1, 0.1))
@@ -230,22 +217,20 @@ class SOAP(Transform):
            # exponential moving averages
            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+            exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]
 
            exp_avg.lerp_(t, 1-beta1)
 
            if t_projected is None:
-                exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t, t, value=1-beta2)
            else:
-                exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
 
            # project exponential moving averages if they are accumulated unprojected
            exp_avg_projected = exp_avg
            if t_projected is not None:
                exp_avg_projected = project(exp_avg, state['Q'])
 
-            exp_avg_sq_projected = exp_avg_sq
-
            denom = exp_avg_sq_projected.sqrt().add_(eps)
            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
 
@@ -273,6 +258,8 @@ class SOAP(Transform):
            if state['GG'] is not None:
                update_soap_covariances_(t, state['GG'], shampoo_beta)
                if state['step'] % setting['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
+                    try:
+                        state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])
+                    except torch.linalg.LinAlgError:
+                        pass