torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/msam.py (new file)
@@ -0,0 +1,186 @@
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
+from ..functional import ema_
+from ..momentum.momentum import nag_
+
+
+def msam_(
+    tensors: TensorList,
+    params: TensorList,
+    velocity_: TensorList,
+    momentum: float | NumberList,
+    lr: NumberList | None,
+    rho: float | NumberList,
+    weight_decay: float | NumberList,
+    nesterov: bool = False,
+    lerp: bool = False,
+
+    # inner args
+    inner: Module | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    # weights w and wh, momentum μ, perturbation strength ρ
+    # w = wh + rho * v / ||v||
+    # v1 = μv + g
+    # w1 = w - lr*v1
+    # wh1 = w1 - rho * v1 / ||v1||
+
+    # w1 = wh + rho * v / ||v|| - lr*v1
+    # vn = rho * v / ||v||
+    # v1n = rho * v1 / ||v1||
+    # wh1 = wh + vn - lr*v1 - v1n
+
+    # the update is
+    # vn - lr*v1 - v1n
+
+    # we track ascent direction so it becomes lr*v1 + v1n - vn
+
+    # can't really decouple it from lr
+    # but at least it is now expressed as function of g
+
+    denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
+    vn = velocity_ / denom
+
+    mom_ = nag_ if nesterov else ema_
+    velocity_ = mom_(tensors, velocity_, momentum, dampening=0, lerp=lerp)
+
+    denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
+    v1n = velocity_ / denom
+
+    if inner is not None:
+        assert params is not None
+        inner_update = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+    else:
+        assert lr is not None
+        inner_update = velocity_ * lr
+
+    update = inner_update.add_(v1n).sub_(vn)
+
+    if generic_ne(weight_decay, 0):
+        wd = (params + vn).mul_(weight_decay)
+        update.add_(wd)
+
+    return update
+
+class MSAM(Transform):
+    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+    This implementation expresses the update rule as a function of the gradient. This way it can be used as a drop-in
+    replacement for momentum strategies in other optimizers.
+
+    To combine MSAM with other optimizers in the way done in the official implementation,
+    e.g. to make Adam_MSAM, use the :code:`tz.m.MSAMObjective` module.
+
+    .. note::
+        MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
+        To avoid compounding learning rate modifications, remove the :code:`tz.m.LR` module if you had it.
+
+    Args:
+        lr (float): learning rate. Adding this module adds support for learning rate schedulers.
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        rho (float, optional): perturbation strength. Defaults to 0.3.
+        weight_decay (float, optional):
+            weight decay. It is applied to perturbed parameters, so it is different
+            from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
+        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
+
+    Examples:
+        MSAM
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.MSAM(1e-3)
+            )
+
+        Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+        To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+                tz.m.Debias(0.9, 0.999),
+            )
+    """
+    USES_LR = True
+    def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
+        defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
+        if self.USES_LR: defaults['lr'] = lr
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+        s = self.settings[params[0]]
+        lerp = s['lerp']
+        nesterov = s['nesterov']
+
+        if self.USES_LR:
+            lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
+
+        else:
+            lr=None
+            momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)
+
+        return msam_(
+            TensorList(tensors),
+            params=TensorList(params),
+            velocity_=velocity,
+            momentum=momentum,
+            lr=lr,
+            rho=rho,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            lerp=lerp,
+
+            # inner args
+            inner=self.children.get("modules", None),
+            grads=grads,
+        )
+
+
+class MSAMObjective(MSAM):
+    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+    .. note::
+        Please make sure to place :code:`tz.m.LR` inside the :code:`modules` argument. For example,
+        :code:`tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])`. Putting LR after MSAM will lead
+        to an incorrect update rule.
+
+    Args:
+        modules (Chainable): modules that will optimize the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        rho (float, optional): perturbation strength. Defaults to 0.3.
+        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, MSAM momentum becomes similar to exponential moving average.
+            Defaults to False.
+
+    Examples:
+        AdamW-MSAM
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.MSAMObjective(
+                    [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+                    rho=1.
+                )
+            )
+    """
+    USES_LR = False
+    def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
+        super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+        self.set_child('modules', modules)
+
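The comments in :code:`msam_` above reparameterize the paper's two-variable MSAM update so that it can be applied directly to the unperturbed weights. Below is a small standalone check of that algebra; it is illustrative only, the tensor names are ad-hoc and nothing here is part of torchzero's API.

.. code-block:: python

    import torch

    torch.manual_seed(0)
    wh = torch.randn(5)      # unperturbed ("hat") weights
    v = torch.randn(5)       # momentum buffer
    g = torch.randn(5)       # gradient
    mu, lr, rho = 0.9, 1e-2, 0.3

    # paper form: perturb, update momentum, step, re-perturb
    w = wh + rho * v / v.norm()
    v1 = mu * v + g
    w1 = w - lr * v1
    wh1 = w1 - rho * v1 / v1.norm()

    # reformulated single-update form tracked by msam_
    vn = rho * v / v.norm()
    v1n = rho * v1 / v1.norm()
    assert torch.allclose(wh1, wh + (vn - lr * v1 - v1n))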
torchzero/modules/optimizers/muon.py
@@ -19,6 +19,7 @@ def _is_at_least_2d(p: torch.Tensor):
 
 # stolen from:
 # https://github.com/KellerJordan/Muon/blob/master/muon.py
+# actually at this stage its a frankenstein
 @enable_compilation
 def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
     """
@@ -152,7 +153,7 @@ class Orthogonalize(TensorwiseTransform):
     The Muon page says that embeddings and classifier heads should not be orthogonalized.
     Usually only matrix parameters that are directly used in matmuls should be orthogonalized.
 
-    To make Muon, use Split with Adam on 1d params: TODO code example.
+    To make Muon, use Split with Adam on 1d params
 
     Args:
         ns_steps (int, optional):
@@ -165,6 +166,30 @@ class Orthogonalize(TensorwiseTransform):
            Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
        target (str, optional):
            what to set on var.
+
+
+    Examples:
+        standard Muon with Adam fallback
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.head.parameters(),
+                tz.m.Split(
+                    # apply muon only to 2D+ parameters
+                    filter = lambda t: t.ndim >= 2,
+                    true = [
+                        tz.m.HeavyBall(),
+                        tz.m.Orthogonalize(),
+                        tz.m.LR(1e-2),
+                    ],
+                    false = tz.m.Adam()
+                ),
+                tz.m.LR(1e-2)
+            )
+
+    Reference:
+        Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
     """
     def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
                  method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
@@ -172,9 +197,9 @@ class Orthogonalize(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
-            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(settings)
+            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)
 
         if not orthogonalize: return tensor
 
@@ -199,7 +224,7 @@ class DualNormCorrection(TensorwiseTransform):
     def __init__(self, target: Target='update'):
         super().__init__({}, uses_grad=True, target=target)
 
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         assert grad is not None
         if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
             return _dual_norm_correction(tensor, grad, batch_first=False)
@@ -213,7 +238,7 @@ class MuonAdjustLR(Transform):
         defaults = dict(alpha=alpha)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alphas = [s['alpha'] for s in settings]
        tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
        tensors = [i[0] for i in tensors_alphas]
torchzero/modules/optimizers/orthograd.py
@@ -36,7 +36,7 @@ class OrthoGrad(Transform):
         defaults = dict(eps=eps, renormalize=renormalize)
         super().__init__(defaults, uses_grad=False, target=target)
 
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         eps = settings[0]['eps']
         renormalize = settings[0]['renormalize']
 
torchzero/modules/optimizers/rmsprop.py
@@ -40,7 +40,9 @@ def rmsprop_(
     return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
 
 class RMSprop(Transform):
-    """Divides graient by EMA of gradient squares. Matches pytorch RMSprop if "init" is set to "zeros".
+    """Divides gradient by EMA of gradient squares.
+
+    This implementation is identical to :code:`torch.optim.RMSprop`.
 
     Args:
         smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
@@ -50,7 +52,8 @@ class RMSprop(Transform):
         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
         pow (float, optional): power used in second momentum power and root. Defaults to 2.
         init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
-        inner (Chainable | None, optional): Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
     """
     def __init__(
         self,
@@ -60,7 +63,7 @@ class RMSprop(Transform):
         debiased: bool = False,
         amsgrad: bool = False,
         pow: float = 2,
-        init: Literal["zeros", "update"] = "update",
+        init: Literal["zeros", "update"] = "zeros",
         inner: Chainable | None = None,
     ):
         defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
@@ -69,7 +72,7 @@ class RMSprop(Transform):
         if inner is not None:
             self.set_child('inner', inner)
 
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
         smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
         centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])
torchzero/modules/optimizers/rprop.py
@@ -135,7 +135,8 @@ class Rprop(Transform):
     Next step, magnitude for that weight won't change.
 
     Compared to pytorch this also implements backtracking update when sign changes.
-    To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.
+
+    This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.
 
     Args:
         nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
@@ -164,7 +165,7 @@ class Rprop(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -223,7 +224,7 @@ class ScaleLRBySignChange(Transform):
         super().__init__(defaults, uses_grad=use_grad, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -272,7 +273,7 @@ class BacktrackOnSignChange(Transform):
         super().__init__(defaults, uses_grad=use_grad)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -294,12 +295,29 @@ class BacktrackOnSignChange(Transform):
         return tensors
 
 class SignConsistencyMask(Transform):
-    """0 if sign changed 1 otherwise"""
+    """
+    Outputs a mask of sign consistency of current and previous inputs.
+
+    The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.
+
+    Examples:
+
+        GD that skips update for weights where gradient sign changed compared to previous gradient.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyMask()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(self,target: Target = 'update'):
         super().__init__({}, uses_grad=False, target = target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', cls=TensorList)
         mask = prev.mul_(tensors).gt_(0)
         prev.copy_(tensors)
@@ -307,7 +325,23 @@ class SignConsistencyMask(Transform):
 
 
 class SignConsistencyLRs(Transform):
-    """LR for each weight is increased when two consequtive update signs are the same, decreased otherwise. This returns the LRs themselves."""
+    """Outputs per-weight learning rates based on consecutive sign consistency.
+
+    The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nminus`. The learning rates are bounded to be in :code:`(lb, ub)` range.
+
+    Examples:
+
+        GD scaled by consecutive gradient sign consistency
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyLRs()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(
         self,
         nplus: float = 1.2,
@@ -321,7 +355,7 @@ class SignConsistencyLRs(Transform):
         super().__init__(defaults, uses_grad=False, target = target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
torchzero/modules/optimizers/sam.py (new file)
@@ -0,0 +1,163 @@
+from contextlib import nullcontext
+import torch
+from ...utils import TensorList, NumberList
+from ...core import Module
+
+
+class SAM(Module):
+    """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+        asam (bool, optional):
+            enables ASAM variant which makes perturbation relative to weight magnitudes.
+            ASAM requires a much larger :code:`rho`, like 0.5 or 1.
+            The :code:`tz.m.ASAM` class is identical to setting this argument to True, but
+            it has larger :code:`rho` by default.
+
+    Examples:
+        SAM-SGD:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.LR(1e-2)
+            )
+
+        SAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
+    """
+    def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
+        defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+
+        params = var.params
+        closure = var.closure
+        zero_grad = var.zero_grad
+        if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
+        p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
+        s = self.settings[var.params[0]]
+        eps = s['eps']
+        asam = s['asam']
+
+        # 1/p + 1/q = 1
+        # okay, authors of SAM paper, I will manually solve your equation
+        # so q = -p/(1-p)
+        q = -p / (1-p)
+        # as a validation for 2 it is -2 / -1 = 2
+
+        @torch.no_grad
+        def sam_closure(backward=True):
+            orig_grads = None
+            if not backward:
+                # if backward is False, make sure this doesn't modify gradients
+                # to avoid issues
+                orig_grads = [p.grad for p in params]
+
+            # gradient at initial parameters
+            zero_grad()
+            with torch.enable_grad():
+                closure()
+
+            grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
+            grad_abs = grad.abs()
+
+            # compute e
+            term1 = grad.sign().mul_(rho)
+            term2 = grad_abs.pow(q-1)
+
+            if asam:
+                grad_abs.mul_(torch._foreach_abs(params))
+
+            denom = grad_abs.pow_(q).sum().pow(1/p)
+
+            e = term1.mul_(term2).div_(denom.clip(min=eps))
+
+            if asam:
+                e.mul_(torch._foreach_pow(params, 2))
+
+            # calculate loss and gradient approximation of inner problem
+            torch._foreach_add_(params, e)
+            if backward:
+                zero_grad()
+                with torch.enable_grad():
+                    # this sets .grad attributes
+                    sam_loss = closure()
+
+            else:
+                sam_loss = closure(False)
+
+            # and restore initial parameters
+            torch._foreach_sub_(params, e)
+
+            if orig_grads is not None:
+                for param,orig_grad in zip(params, orig_grads):
+                    param.grad = orig_grad
+
+            return sam_loss
+
+        var.closure = sam_closure
+        return var
+
+# different class because defaults for SAM are bad for ASAM
+class ASAM(SAM):
+    """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.5.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+
+    Examples:
+        ASAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ASAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
+    """
+    def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
+        super().__init__(rho=rho, p=p, eps=eps, asam=True)
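In :code:`sam_closure` above, the perturbation :code:`e` is the closed-form solution of the inner maximization, with the dual exponent obtained from :code:`1/p + 1/q = 1`, i.e. :code:`q = -p/(1-p)`. A small sanity check, illustrative only and not part of torchzero, that for the default :code:`p = 2` this reduces to the familiar :code:`rho * g / ||g||` from the SAM paper:

.. code-block:: python

    import torch

    torch.manual_seed(0)
    g, rho, p = torch.randn(10), 0.05, 2.0
    q = -p / (1 - p)   # solves 1/p + 1/q = 1, so q = 2 when p = 2

    # closed-form perturbation, mirroring the term1 / term2 / denom computation above
    e = rho * g.sign() * g.abs().pow(q - 1) / g.abs().pow(q).sum().pow(1 / p)

    assert torch.allclose(e, rho * g / g.norm())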
torchzero/modules/optimizers/shampoo.py
@@ -59,7 +59,7 @@ def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
     if tensor.shape[sort_idxs[0]] > max_dim:
         return tensor, None, None
 
-    tensor = tensor.permute(*sort_idxs)
+    tensor = tensor.permute(*sort_idxs.tolist())
     flatten_end_idx = 0
     flat_sizes = []
     flat_numel = 1
@@ -80,19 +80,28 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     if flat_sizes is None: return tensor
     assert sort_idxs is not None
     tensor = tensor.unflatten(0, flat_sizes)
-    return tensor.permute(*np.argsort(sort_idxs))
+    return tensor.permute(*np.argsort(sort_idxs).tolist())
 
 
 class Shampoo(Transform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).
 
+    .. note::
+        Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.
+
+    .. note::
+        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+
+    .. note::
+        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+
     Args:
         decay (float | None, optional): slowly decays preconditioners. Defaults to None.
         beta (float | None, optional):
            if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
        matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
        update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to None.
+        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
        merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
        precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
@@ -101,13 +110,38 @@ class Shampoo(Transform):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.
+
+    Examples:
+        Shampoo grafted to Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GraftModules(
+                    direction = tz.m.Shampoo(),
+                    magnitude = tz.m.Adam(),
+                ),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with Shampoo preconditioner
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+                tz.m.Debias(0.9, 0.999),
+                tz.m.LR(1e-3)
+            )
     """
     def __init__(
         self,
         decay: float | None = None,
         beta: float | None = None,
         update_freq: int = 10,
-        exp_override: int | None = None,
+        exp_override: int | None = 2,
         merge_small: bool = True,
         max_dim: int = 2_000,
         precondition_1d: bool = True,
@@ -120,7 +154,7 @@ class Shampoo(Transform):
         if inner is not None:
             self.set_child('inner', inner)
 
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         merged_tensors = [] # target with merged dims
 
         # update preconditioners