torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/rprop.py
@@ -2,7 +2,7 @@
  import torch

  from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList, as_tensorlist
+ from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states


  def _bool_ones_like(x):
@@ -135,7 +135,8 @@ class Rprop(Transform):
  Next step, magnitude for that weight won't change.

  Compared to pytorch this also implements backtracking update when sign changes.
- To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.
+
+ This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.

  Args:
  nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
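For orientation, here is a minimal usage sketch of the module documented above, using the same tz.Modular / tz.m chaining shown in the other docstrings in this diff and assuming torchzero is imported as tz; it is illustrative only and not part of the package diff:

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.Rprop(),       # per the docstring, backtrack=False mirrors torch.optim.Rprop
        tz.m.LR(1e-2),
    )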
@@ -161,20 +162,22 @@
  alpha: float = 1,
  ):
  defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
- self.current_step = 0
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- nplus, nminus, lb, ub, alpha = self.get_settings('nplus', 'nminus', 'lb', 'ub', 'alpha', params=params, cls=NumberList)
- prev, allowed, magnitudes = self.get_state(
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step = self.global_state.get('step', 0)
+ self.global_state['step'] = step + 1
+
+ nplus, nminus, lb, ub, alpha = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', 'alpha', cls=NumberList)
+ prev, allowed, magnitudes = unpack_states(
+ states, tensors,
  'prev','allowed','magnitudes',
- params=params,
  init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
  cls = TensorList,
  )

- target = rprop_(
+ tensors = rprop_(
  tensors_ = as_tensorlist(tensors),
  prev_ = prev,
  allowed_ = allowed,
@@ -184,12 +187,11 @@
  lb = lb,
  ub = ub,
  alpha = alpha,
- backtrack=self.settings[params[0]]['backtrack'],
- step=self.current_step,
+ backtrack=settings[0]['backtrack'],
+ step=step,
  )

- self.current_step += 1
- return target
+ return tensors


  class ScaleLRBySignChange(Transform):
@@ -220,23 +222,25 @@ class ScaleLRBySignChange(Transform):
  ):
  defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
  super().__init__(defaults, uses_grad=use_grad, target=target)
- self.current_step = 0

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- target = as_tensorlist(tensors)
- use_grad = self.settings[params[0]]['use_grad']
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step = self.global_state.get('step', 0)
+ self.global_state['step'] = step + 1
+
+ tensors = as_tensorlist(tensors)
+ use_grad = settings[0]['use_grad']
  if use_grad: cur = as_tensorlist(grads)
- else: cur = target
+ else: cur = tensors

- nplus, nminus, lb, ub = self.get_settings('nplus', 'nminus', 'lb', 'ub', params=params, cls=NumberList)
- prev, lrs = self.get_state('prev', 'lrs', params=params, cls=TensorList)
+ nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
+ prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

- if self.current_step == 0:
- lrs.set_(target.full_like(self.get_settings('alpha', params=params)))
+ if step == 0:
+ lrs.set_(tensors.full_like([s['alpha'] for s in settings]))

- target = scale_by_sign_change_(
- tensors_ = target,
+ tensors = scale_by_sign_change_(
+ tensors_ = tensors,
  cur = cur,
  prev_ = prev,
  lrs_ = lrs,
@@ -244,10 +248,9 @@
  nminus = nminus,
  lb = lb,
  ub = ub,
- step = self.current_step,
+ step = step,
  )
- self.current_step += 1
- return target
+ return tensors

  class BacktrackOnSignChange(Transform):
  """Negates or undoes update for parameters where where gradient or update sign changes.
@@ -268,44 +271,77 @@ class BacktrackOnSignChange(Transform):
  def __init__(self, use_grad = False, backtrack = True, target: Target = 'update'):
  defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
  super().__init__(defaults, uses_grad=use_grad)
- self.current_step = 0

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- target = as_tensorlist(tensors)
- settings = self.settings[params[0]]
- use_grad = settings['use_grad']
- backtrack = settings['backtrack']
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step = self.global_state.get('step', 0)
+ self.global_state['step'] = step + 1
+
+ tensors = as_tensorlist(tensors)
+ use_grad = settings[0]['use_grad']
+ backtrack = settings[0]['backtrack']

  if use_grad: cur = as_tensorlist(grads)
- else: cur = target
+ else: cur = tensors

- target = backtrack_on_sign_change_(
- tensors_ = target,
+ tensors = backtrack_on_sign_change_(
+ tensors_ = tensors,
  cur = cur,
- prev_ = self.get_state('prev', params=params, cls=TensorList),
+ prev_ = unpack_states(states, tensors, 'prev', cls=TensorList),
  backtrack = backtrack,
- step = self.current_step,
+ step = step,
  )

- self.current_step += 1
- return target
+ return tensors

  class SignConsistencyMask(Transform):
- """0 if sign changed 1 otherwise"""
+ """
+ Outputs a mask of sign consistency of current and previous inputs.
+
+ The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.
+
+ Examples:
+
+ GD that skips update for weights where gradient sign changed compared to previous gradient.
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Mul(tz.m.SignConsistencyMask()),
+ tz.m.LR(1e-2)
+ )
+
+ """
  def __init__(self,target: Target = 'update'):
  super().__init__({}, uses_grad=False, target = target)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- prev = self.get_state('prev', params=params, cls=TensorList)
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ prev = unpack_states(states, tensors, 'prev', cls=TensorList)
  mask = prev.mul_(tensors).gt_(0)
- prev.set_(tensors)
+ prev.copy_(tensors)
  return mask


  class SignConsistencyLRs(Transform):
- """LR for each weight is increased when two consequtive update signs are the same, decreased otherwise. This returns the LRs themselves."""
+ """Outputs per-weight learning rates based on consecutive sign consistency.
+
+ The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nplus`. The learning rates are bounded to be in :code:`(lb, ub)` range.
+
+ Examples:
+
+ GD scaled by consecutive gradient sign consistency
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Mul(tz.m.SignConsistencyLRs()),
+ tz.m.LR(1e-2)
+ )
+
+ """
  def __init__(
  self,
  nplus: float = 1.2,
@@ -317,16 +353,18 @@ class SignConsistencyLRs(Transform):
  ):
  defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
  super().__init__(defaults, uses_grad=False, target = target)
- self.current_step = 0

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step = self.global_state.get('step', 0)
+ self.global_state['step'] = step + 1
+
  target = as_tensorlist(tensors)
- nplus, nminus, lb, ub = self.get_settings('nplus', 'nminus', 'lb', 'ub', params=params, cls=NumberList)
- prev, lrs = self.get_state('prev', 'lrs', params=params, cls=TensorList)
+ nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
+ prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

- if self.current_step == 0:
- lrs.set_(target.full_like(self.get_settings('alpha', params=params)))
+ if step == 0:
+ lrs.set_(target.full_like([s['alpha'] for s in settings]))

  target = sign_consistency_lrs_(
  tensors = target,
@@ -336,7 +374,6 @@
  nminus = nminus,
  lb = lb,
  ub = ub,
- step = self.current_step,
+ step = step,
  )
- self.current_step += 1
  return target.clone()
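The change repeated across these rprop.py hunks is the move from the old transform(self, tensors, params, grads, vars) hook with a per-instance self.current_step counter to apply_tensors(self, tensors, params, grads, loss, states, settings), with per-parameter settings and state dicts unpacked via unpack_dicts / unpack_states and the step counter kept in self.global_state. A rough sketch of a custom transform written against the new hook, mirroring the SignConsistencyMask code above; the base-class plumbing is assumed, so treat it as illustrative rather than canonical:

    import torch
    from torchzero.core import Transform
    from torchzero.utils import TensorList, unpack_states

    class PrevSignMask(Transform):
        # hypothetical example module, not part of this diff
        def __init__(self):
            super().__init__({}, uses_grad=False)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            # per-module counters now live in global_state instead of instance attributes
            step = self.global_state.get('step', 0)
            self.global_state['step'] = step + 1

            # states and settings are per-parameter dicts, unpacked with the new helpers
            prev = unpack_states(states, tensors, 'prev', cls=TensorList)
            mask = prev.mul_(tensors).gt_(0)   # 1 where the sign matches the previous input
            prev.copy_(tensors)
            return mask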
torchzero/modules/optimizers/sam.py (new file)
@@ -0,0 +1,163 @@
+ from contextlib import nullcontext
+ import torch
+ from ...utils import TensorList, NumberList
+ from ...core import Module
+
+
+ class SAM(Module):
+ """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412
+
+ SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+ It performs two forward and backward passes per step.
+
+ This implementation modifies the closure to return loss and calculate gradients
+ of the SAM objective. All modules after this will use the modified objective.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients at two points on each step.
+
+ Args:
+ rho (float, optional): Neighborhood size. Defaults to 0.05.
+ p (float, optional): norm of the SAM objective. Defaults to 2.
+ asam (bool, optional):
+ enables ASAM variant which makes perturbation relative to weight magnitudes.
+ ASAM requires a much larger :code:`rho`, like 0.5 or 1.
+ The :code:`tz.m.ASAM` class is idential to setting this argument to True, but
+ it has larger :code:`rho` by default.
+
+ Examples:
+ SAM-SGD:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.SAM(),
+ tz.m.LR(1e-2)
+ )
+
+ SAM-Adam:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.SAM(),
+ tz.m.Adam(),
+ tz.m.LR(1e-2)
+ )
+
+ References:
+ Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
+ """
+ def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
+ defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
+ super().__init__(defaults)
+
+ @torch.no_grad
+ def step(self, var):
+
+ params = var.params
+ closure = var.closure
+ zero_grad = var.zero_grad
+ if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
+ p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
+ s = self.settings[var.params[0]]
+ eps = s['eps']
+ asam = s['asam']
+
+ # 1/p + 1/q = 1
+ # okay, authors of SAM paper, I will manually solve your equation
+ # so q = -p/(1-p)
+ q = -p / (1-p)
+ # as a validation for 2 it is -2 / -1 = 2
+
+ @torch.no_grad
+ def sam_closure(backward=True):
+ orig_grads = None
+ if not backward:
+ # if backward is False, make sure this doesn't modify gradients
+ # to avoid issues
+ orig_grads = [p.grad for p in params]
+
+ # gradient at initial parameters
+ zero_grad()
+ with torch.enable_grad():
+ closure()
+
+ grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
+ grad_abs = grad.abs()
+
+ # compute e
+ term1 = grad.sign().mul_(rho)
+ term2 = grad_abs.pow(q-1)
+
+ if asam:
+ grad_abs.mul_(torch._foreach_abs(params))
+
+ denom = grad_abs.pow_(q).sum().pow(1/p)
+
+ e = term1.mul_(term2).div_(denom.clip(min=eps))
+
+ if asam:
+ e.mul_(torch._foreach_pow(params, 2))
+
+ # calculate loss and gradient approximation of inner problem
+ torch._foreach_add_(params, e)
+ if backward:
+ zero_grad()
+ with torch.enable_grad():
+ # this sets .grad attributes
+ sam_loss = closure()
+
+ else:
+ sam_loss = closure(False)
+
+ # and restore initial parameters
+ torch._foreach_sub_(params, e)
+
+ if orig_grads is not None:
+ for param,orig_grad in zip(params, orig_grads):
+ param.grad = orig_grad
+
+ return sam_loss
+
+ var.closure = sam_closure
+ return var
+
+ # different class because defaults for SAM are bad for ASAM
+ class ASAM(SAM):
+ """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
+
+ SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+ It performs two forward and backward passes per step.
+
+ This implementation modifies the closure to return loss and calculate gradients
+ of the SAM objective. All modules after this will use the modified objective.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients at two points on each step.
+
+ Args:
+ rho (float, optional): Neighborhood size. Defaults to 0.05.
+ p (float, optional): norm of the SAM objective. Defaults to 2.
+
+ Examples:
+ ASAM-Adam:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.ASAM(),
+ tz.m.Adam(),
+ tz.m.LR(1e-2)
+ )
+
+ References:
+ Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
+ """
+ def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
+ super().__init__(rho=rho, p=p, eps=eps, asam=True)
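The closure above solves 1/p + 1/q = 1 for q and builds the perturbation e = rho * sign(g) * |g|**(q-1) / (sum |g|**q)**(1/p). Ignoring the eps clamp and the ASAM scaling, this reduces to a step of size rho along the normalized gradient at the default p = 2; a quick standalone sanity check in plain PyTorch, independent of torchzero:

    import torch

    torch.manual_seed(0)
    g = torch.randn(5)          # stand-in for the flattened gradient
    rho, p = 0.05, 2.0
    q = -p / (1 - p)            # from 1/p + 1/q = 1; q == 2 when p == 2

    e = rho * g.sign() * g.abs().pow(q - 1) / g.abs().pow(q).sum().pow(1 / p)
    assert torch.allclose(e, rho * g / g.norm())   # p = 2 case: rho * g / ||g||_2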
torchzero/modules/optimizers/shampoo.py
@@ -4,7 +4,7 @@ from functools import partial
  import numpy as np
  import torch

- from ...core import Chainable, Transform, apply
+ from ...core import Chainable, Transform, apply_transform
  from ...utils.linalg import matrix_power_eigh
  from ...utils import set_storage_

@@ -59,7 +59,7 @@ def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
  if tensor.shape[sort_idxs[0]] > max_dim:
  return tensor, None, None

- tensor = tensor.permute(*sort_idxs)
+ tensor = tensor.permute(*sort_idxs.tolist())
  flatten_end_idx = 0
  flat_sizes = []
  flat_numel = 1
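Both permute calls in this file now go through .tolist(), presumably so that Tensor.permute receives plain Python ints rather than NumPy index arrays; np.argsort of the stored permutation gives its inverse, so the merge/unmerge pair restores the original layout. A small standalone illustration of that round trip in plain NumPy/PyTorch (not torchzero code):

    import numpy as np
    import torch

    x = torch.randn(2, 5, 3)
    sort_idxs = np.argsort(x.shape)                  # dims ordered by size, e.g. [0, 2, 1]
    y = x.permute(*sort_idxs.tolist())               # forward permutation
    z = y.permute(*np.argsort(sort_idxs).tolist())   # argsort of a permutation inverts it
    assert z.shape == x.shape and torch.equal(z, x)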
@@ -80,19 +80,28 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
  if flat_sizes is None: return tensor
  assert sort_idxs is not None
  tensor = tensor.unflatten(0, flat_sizes)
- return tensor.permute(*np.argsort(sort_idxs))
+ return tensor.permute(*np.argsort(sort_idxs).tolist())


  class Shampoo(Transform):
  """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

+ .. note::
+ Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.
+
+ .. note::
+ Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+
+ .. note::
+ SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+
  Args:
  decay (float | None, optional): slowly decays preconditioners. Defaults to None.
  beta (float | None, optional):
  if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
  matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
  update_freq (int, optional): preconditioner update frequency. Defaults to 10.
- exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to None.
+ exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
  merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
  max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
  precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
@@ -101,39 +110,62 @@ class Shampoo(Transform):
  module applied after updating preconditioners and before applying preconditioning.
  For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
  Defaults to None.
+
+ Examples:
+ Shampoo grafted to Adam
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.GraftModules(
+ direction = tz.m.Shampoo(),
+ magnitude = tz.m.Adam(),
+ ),
+ tz.m.LR(1e-3)
+ )
+
+ Adam with Shampoo preconditioner
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+ tz.m.Debias(0.9, 0.999),
+ tz.m.LR(1e-3)
+ )
  """
  def __init__(
  self,
  decay: float | None = None,
  beta: float | None = None,
- reg: float = 1e-6,
  update_freq: int = 10,
- exp_override: int | None = None,
+ exp_override: int | None = 2,
  merge_small: bool = True,
  max_dim: int = 2_000,
  precondition_1d: bool = True,
  adagrad_eps: float = 1e-8,
  inner: Chainable | None = None,
  ):
- defaults = dict(decay=decay, beta=beta, reg=reg, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
+ defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
  super().__init__(defaults, uses_grad=False)

  if inner is not None:
  self.set_child('inner', inner)

- def transform(self, tensors, params, grads, vars):
- merged_target = [] # target with merged dims
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ merged_tensors = [] # target with merged dims

  # update preconditioners
- for i,(p,t) in enumerate(zip(params, tensors)):
- state = self.state[p]
- settings = self.settings[p]
- beta, reg, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
- 'beta', 'reg', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(settings)
+ for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
+ beta, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
+ 'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(setting)

  if merge_small:
  t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
- merged_target.append(t)
+
+ merged_tensors.append(t)

  # initialize accumulators and preconditioners for each dim on 1st step
  if 'accumulators' not in state:
@@ -167,22 +199,18 @@

  # inner step
  if 'inner' in self.children:
- tensors = apply(self.children['inner'], tensors, params=params, grads=grads, vars=vars)
+ tensors = apply_transform(self.children['inner'], tensors, params=params, grads=grads)

  # have to merge small dims again
- merged_target = [] # target with merged dims
- for i,(p,t) in enumerate(zip(params, tensors)):
- state = self.state[p]
- settings = self.settings[p]
- if settings['merge_small']:
- t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, settings['max_dim'])
- merged_target.append(t)
+ merged_tensors = [] # target with merged dims
+ for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
+ if setting['merge_small']:
+ t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, setting['max_dim'])
+ merged_tensors.append(t)

  # precondition
- for i, (p, t) in enumerate(zip(params, merged_target)):
- state = self.state[p]
- settings = self.settings[p]
- decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(settings)
+ for i,(t,state, setting) in enumerate(zip(merged_tensors, states, settings)):
+ decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(setting)

  if 'diagonal_accumulator' in state:
  tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)