torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_opts.py CHANGED
@@ -54,7 +54,7 @@ class _TestModel(torch.nn.Module):
54
54
  def forward(self):
55
55
  return torch.sum(torch.stack([p.pow(2).sum() for p in self.params]))
56
56
 
57
- def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, steps: int, clear: bool):
57
+ def _run_objective(opt: tz.Optimizer, objective: Callable, use_closure: bool, steps: int, clear: bool):
58
58
  """generic function to run opt on objective and return lowest recorded loss"""
59
59
  losses = []
60
60
  for i in range(steps):
@@ -154,8 +154,8 @@ class Run:
154
154
  Holds arguments for a test.
155
155
 
156
156
  Args:
157
- func_opt (Callable): opt for test function e.g. :code:`lambda p: tz.Modular(p, tz.m.Adam())`
158
- sphere_opt (Callable): opt for sphere e.g. :code:`lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.1))`
157
+ func_opt (Callable): opt for test function e.g. :code:`lambda p: tz.Optimizer(p, tz.m.Adam())`
158
+ sphere_opt (Callable): opt for sphere e.g. :code:`lambda p: tz.Optimizer(p, tz.m.Adam(), tz.m.LR(0.1))`
159
159
  needs_closure (bool): set to True if opt_fn requires closure
160
160
  func (str): what test function to use ("booth", "rosen", "ill")
161
161
  steps (int): number of steps to run test function for.
@@ -176,50 +176,50 @@ class Run:
176
176
  # ---------------------------------------------------------------------------- #
177
177
  # ----------------------------- clipping/clipping ---------------------------- #
178
178
  ClipValue = Run(
179
- func_opt=lambda p: tz.Modular(p, tz.m.ClipValue(1), tz.m.LR(1)),
180
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValue(1), tz.m.LR(1)),
179
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipValue(1), tz.m.LR(1)),
180
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipValue(1), tz.m.LR(1)),
181
181
  needs_closure=False,
182
182
  func='booth', steps=50, loss=0, merge_invariant=True,
183
183
  sphere_steps=10, sphere_loss=50,
184
184
  )
185
185
  ClipNorm = Run(
186
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1), tz.m.LR(1)),
187
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1), tz.m.LR(0.5)),
186
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNorm(1), tz.m.LR(1)),
187
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNorm(1), tz.m.LR(0.5)),
188
188
  needs_closure=False,
189
189
  func='booth', steps=50, loss=2, merge_invariant=False,
190
190
  sphere_steps=10, sphere_loss=0,
191
191
  )
192
192
  ClipNorm_global = Run(
193
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(1)),
194
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(3)),
193
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(1)),
194
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(3)),
195
195
  needs_closure=False,
196
196
  func='booth', steps=50, loss=2, merge_invariant=True,
197
197
  sphere_steps=10, sphere_loss=2,
198
198
  )
199
199
  Normalize = Run(
200
- func_opt=lambda p: tz.Modular(p, tz.m.Normalize(1), tz.m.LR(1)),
201
- sphere_opt=lambda p: tz.Modular(p, tz.m.Normalize(1), tz.m.LR(0.5)),
200
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Normalize(1), tz.m.LR(1)),
201
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Normalize(1), tz.m.LR(0.5)),
202
202
  needs_closure=False,
203
203
  func='booth', steps=50, loss=2, merge_invariant=False,
204
204
  sphere_steps=10, sphere_loss=15,
205
205
  )
206
206
  Normalize_global = Run(
207
- func_opt=lambda p: tz.Modular(p, tz.m.Normalize(1, dim='global'), tz.m.LR(1)),
208
- sphere_opt=lambda p: tz.Modular(p, tz.m.Normalize(1, dim='global'), tz.m.LR(4)),
207
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Normalize(1, dim='global'), tz.m.LR(1)),
208
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Normalize(1, dim='global'), tz.m.LR(4)),
209
209
  needs_closure=False,
210
210
  func='booth', steps=50, loss=2, merge_invariant=True,
211
211
  sphere_steps=10, sphere_loss=2,
212
212
  )
213
213
  Centralize = Run(
214
- func_opt=lambda p: tz.Modular(p, tz.m.Centralize(min_size=3), tz.m.LR(0.1)),
215
- sphere_opt=lambda p: tz.Modular(p, tz.m.Centralize(), tz.m.LR(0.1)),
214
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Centralize(min_size=3), tz.m.LR(0.1)),
215
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Centralize(), tz.m.LR(0.1)),
216
216
  needs_closure=False,
217
217
  func='booth', steps=50, loss=1e-6, merge_invariant=False,
218
218
  sphere_steps=10, sphere_loss=10,
219
219
  )
220
220
  Centralize_global = Run(
221
- func_opt=lambda p: tz.Modular(p, tz.m.Centralize(min_size=3, dim='global'), tz.m.LR(0.1)),
222
- sphere_opt=lambda p: tz.Modular(p, tz.m.Centralize(dim='global'), tz.m.LR(0.1)),
221
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Centralize(min_size=3, dim='global'), tz.m.LR(0.1)),
222
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Centralize(dim='global'), tz.m.LR(0.1)),
223
223
  needs_closure=False,
224
224
  func='booth', steps=1, loss=1000, merge_invariant=True,
225
225
  sphere_steps=10, sphere_loss=10,
@@ -227,72 +227,72 @@ Centralize_global = Run(
227
227
 
228
228
  # --------------------------- clipping/ema_clipping -------------------------- #
229
229
  ClipNormByEMA = Run(
230
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(), tz.m.LR(0.1)),
231
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(), tz.m.LR(5)),
230
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormByEMA(), tz.m.LR(0.1)),
231
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormByEMA(), tz.m.LR(5)),
232
232
  needs_closure=False,
233
233
  func='booth', steps=50, loss=1e-5, merge_invariant=False,
234
234
  sphere_steps=10, sphere_loss=0.1,
235
235
  )
236
236
  ClipNormByEMA_global = Run(
237
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(0.1)),
238
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(5)),
237
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(0.1)),
238
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(5)),
239
239
  needs_closure=False,
240
240
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
241
241
  sphere_steps=10, sphere_loss=0.1,
242
242
  )
243
243
  NormalizeByEMA = Run(
244
- func_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(), tz.m.LR(0.05)),
245
- sphere_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(), tz.m.LR(5)),
244
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NormalizeByEMA(), tz.m.LR(0.05)),
245
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NormalizeByEMA(), tz.m.LR(5)),
246
246
  needs_closure=False,
247
247
  func='booth', steps=50, loss=1, merge_invariant=False,
248
248
  sphere_steps=10, sphere_loss=0.1,
249
249
  )
250
250
  NormalizeByEMA_global = Run(
251
- func_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(0.05)),
252
- sphere_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(5)),
251
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(0.05)),
252
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(5)),
253
253
  needs_closure=False,
254
254
  func='booth', steps=50, loss=1, merge_invariant=True,
255
255
  sphere_steps=10, sphere_loss=0.1,
256
256
  )
257
257
  ClipValueByEMA = Run(
258
- func_opt=lambda p: tz.Modular(p, tz.m.ClipValueByEMA(), tz.m.LR(0.1)),
259
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueByEMA(), tz.m.LR(4)),
258
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueByEMA(), tz.m.LR(0.1)),
259
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueByEMA(), tz.m.LR(4)),
260
260
  needs_closure=False,
261
261
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
262
262
  sphere_steps=10, sphere_loss=0.03,
263
263
  )
264
264
  # ------------------------- clipping/growth_clipping ------------------------- #
265
265
  ClipValueGrowth = Run(
266
- func_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
267
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
266
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
267
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
268
268
  needs_closure=False,
269
269
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
270
270
  sphere_steps=10, sphere_loss=100,
271
271
  )
272
272
  ClipValueGrowth_additive = Run(
273
- func_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
274
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
273
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
274
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
275
275
  needs_closure=False,
276
276
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
277
277
  sphere_steps=10, sphere_loss=10,
278
278
  )
279
279
  ClipNormGrowth = Run(
280
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
281
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
280
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
281
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
282
282
  needs_closure=False,
283
283
  func='booth', steps=50, loss=1e-6, merge_invariant=False,
284
284
  sphere_steps=10, sphere_loss=10,
285
285
  )
286
286
  ClipNormGrowth_additive = Run(
287
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
288
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
287
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
288
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
289
289
  needs_closure=False,
290
290
  func='booth', steps=50, loss=1e-6, merge_invariant=False,
291
291
  sphere_steps=10, sphere_loss=10,
292
292
  )
293
293
  ClipNormGrowth_global = Run(
294
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
295
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
294
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
295
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
296
296
  needs_closure=False,
297
297
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
298
298
  sphere_steps=10, sphere_loss=10,
@@ -300,43 +300,43 @@ ClipNormGrowth_global = Run(
300
300
 
301
301
  # -------------------------- grad_approximation/fdm -------------------------- #
302
302
  FDM_central2 = Run(
303
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
304
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
303
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
304
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(), tz.m.LR(0.1)),
305
305
  needs_closure=True,
306
306
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
307
307
  sphere_steps=2, sphere_loss=340,
308
308
  )
309
309
  FDM_forward2 = Run(
310
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
311
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
310
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
311
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
312
312
  needs_closure=True,
313
313
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
314
314
  sphere_steps=2, sphere_loss=340,
315
315
  )
316
316
  FDM_backward2 = Run(
317
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
318
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
317
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
318
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
319
319
  needs_closure=True,
320
320
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
321
321
  sphere_steps=2, sphere_loss=340,
322
322
  )
323
323
  FDM_forward3 = Run(
324
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
325
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
324
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
325
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
326
326
  needs_closure=True,
327
327
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
328
328
  sphere_steps=2, sphere_loss=340,
329
329
  )
330
330
  FDM_backward3 = Run(
331
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
332
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
331
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
332
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
333
333
  needs_closure=True,
334
334
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
335
335
  sphere_steps=2, sphere_loss=340,
336
336
  )
337
337
  FDM_central4 = Run(
338
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
339
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
338
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
339
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
340
340
  needs_closure=True,
341
341
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
342
342
  sphere_steps=2, sphere_loss=340,
@@ -344,57 +344,57 @@ FDM_central4 = Run(
344
344
 
345
345
  # -------------------------- grad_approximation/rfdm ------------------------- #
346
346
  RandomizedFDM_central2 = Run(
347
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.01)),
348
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.001)),
347
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.01)),
348
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.001)),
349
349
  needs_closure=True,
350
350
  func='booth', steps=50, loss=10, merge_invariant=True,
351
351
  sphere_steps=200, sphere_loss=420,
352
352
  )
353
353
  RandomizedFDM_forward2 = Run(
354
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.01)),
355
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.001)),
354
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.01)),
355
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.001)),
356
356
  needs_closure=True,
357
357
  func='booth', steps=50, loss=10, merge_invariant=True,
358
358
  sphere_steps=200, sphere_loss=420,
359
359
  )
360
360
  RandomizedFDM_backward2 = Run(
361
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.01)),
362
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.001)),
361
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.01)),
362
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.001)),
363
363
  needs_closure=True,
364
364
  func='booth', steps=50, loss=10, merge_invariant=True,
365
365
  sphere_steps=200, sphere_loss=420,
366
366
  )
367
367
  RandomizedFDM_forward3 = Run(
368
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.01)),
369
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.001)),
368
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.01)),
369
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.001)),
370
370
  needs_closure=True,
371
371
  func='booth', steps=50, loss=10, merge_invariant=True,
372
372
  sphere_steps=200, sphere_loss=420,
373
373
  )
374
374
  RandomizedFDM_backward3 = Run(
375
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.01)),
376
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.001)),
375
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.01)),
376
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.001)),
377
377
  needs_closure=True,
378
378
  func='booth', steps=50, loss=10, merge_invariant=True,
379
379
  sphere_steps=200, sphere_loss=420,
380
380
  )
381
381
  RandomizedFDM_central4 = Run(
382
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.01)),
383
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.001)),
382
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.01)),
383
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.001)),
384
384
  needs_closure=True,
385
385
  func='booth', steps=50, loss=10, merge_invariant=True,
386
386
  sphere_steps=200, sphere_loss=420,
387
387
  )
388
388
  RandomizedFDM_forward4 = Run(
389
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
390
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
389
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
390
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
391
391
  needs_closure=True,
392
392
  func='booth', steps=50, loss=10, merge_invariant=True,
393
393
  sphere_steps=200, sphere_loss=420,
394
394
  )
395
395
  RandomizedFDM_forward5 = Run(
396
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
397
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
396
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
397
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
398
398
  needs_closure=True,
399
399
  func='booth', steps=50, loss=10, merge_invariant=True,
400
400
  sphere_steps=200, sphere_loss=420,
@@ -402,65 +402,65 @@ RandomizedFDM_forward5 = Run(
402
402
 
403
403
 
404
404
  RandomizedFDM_4samples = Run(
405
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
406
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.001)),
405
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
406
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.001)),
407
407
  needs_closure=True,
408
408
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
409
409
  sphere_steps=100, sphere_loss=400,
410
410
  )
411
411
  RandomizedFDM_4samples_no_pre_generate = Run(
412
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
413
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
412
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
413
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
414
414
  needs_closure=True,
415
415
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
416
416
  sphere_steps=100, sphere_loss=400,
417
417
  )
418
418
  MeZO = Run(
419
- func_opt=lambda p: tz.Modular(p, tz.m.MeZO(), tz.m.LR(0.01)),
420
- sphere_opt=lambda p: tz.Modular(p, tz.m.MeZO(), tz.m.LR(0.001)),
419
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MeZO(), tz.m.LR(0.01)),
420
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MeZO(), tz.m.LR(0.001)),
421
421
  needs_closure=True,
422
422
  func='booth', steps=50, loss=5, merge_invariant=True,
423
423
  sphere_steps=100, sphere_loss=450,
424
424
  )
425
425
  MeZO_4samples = Run(
426
- func_opt=lambda p: tz.Modular(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.02)),
427
- sphere_opt=lambda p: tz.Modular(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.005)),
426
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.02)),
427
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.005)),
428
428
  needs_closure=True,
429
429
  func='booth', steps=50, loss=1, merge_invariant=True,
430
430
  sphere_steps=100, sphere_loss=250,
431
431
  )
432
432
  # -------------------- grad_approximation/forward_gradient ------------------- #
433
433
  ForwardGradient = Run(
434
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.01)),
435
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.001)),
434
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.01)),
435
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.001)),
436
436
  needs_closure=True,
437
437
  func='booth', steps=50, loss=40, merge_invariant=True,
438
438
  sphere_steps=200, sphere_loss=450,
439
439
  )
440
440
  ForwardGradient_forward = Run(
441
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.01)),
442
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.001)),
441
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.01)),
442
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.001)),
443
443
  needs_closure=True,
444
444
  func='booth', steps=50, loss=40, merge_invariant=True,
445
445
  sphere_steps=200, sphere_loss=450,
446
446
  )
447
447
  ForwardGradient_central = Run(
448
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.01)),
449
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.001)),
448
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.01)),
449
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.001)),
450
450
  needs_closure=True,
451
451
  func='booth', steps=50, loss=40, merge_invariant=True,
452
452
  sphere_steps=200, sphere_loss=450,
453
453
  )
454
454
  ForwardGradient_4samples = Run(
455
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.1)),
456
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.001)),
455
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.1)),
456
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.001)),
457
457
  needs_closure=True,
458
458
  func='booth', steps=50, loss=0.1, merge_invariant=True,
459
459
  sphere_steps=100, sphere_loss=420,
460
460
  )
461
461
  ForwardGradient_4samples_no_pre_generate = Run(
462
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.1)),
463
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.001)),
462
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.1)),
463
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.001)),
464
464
  needs_closure=True,
465
465
  func='booth', steps=50, loss=0.1, merge_invariant=True,
466
466
  sphere_steps=100, sphere_loss=420,
@@ -468,23 +468,23 @@ ForwardGradient_4samples_no_pre_generate = Run(
468
468
 
469
469
  # ------------------------- line_search/backtracking ------------------------- #
470
470
  Backtracking = Run(
471
- func_opt=lambda p: tz.Modular(p, tz.m.Backtracking()),
472
- sphere_opt=lambda p: tz.Modular(p, tz.m.Backtracking()),
471
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Backtracking()),
472
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Backtracking()),
473
473
  needs_closure=True,
474
474
  func='booth', steps=50, loss=0, merge_invariant=True,
475
475
  sphere_steps=2, sphere_loss=0,
476
476
  )
477
477
  AdaptiveBacktracking = Run(
478
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
479
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
478
+ func_opt=lambda p: tz.Optimizer(p, tz.m.AdaptiveBacktracking()),
479
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.AdaptiveBacktracking()),
480
480
  needs_closure=True,
481
481
  func='booth', steps=50, loss=1e-11, merge_invariant=True,
482
482
  sphere_steps=2, sphere_loss=1e-10,
483
483
  )
484
484
  # ----------------------------- line_search/scipy ---------------------------- #
485
485
  ScipyMinimizeScalar = Run(
486
- func_opt=lambda p: tz.Modular(p, tz.m.ScipyMinimizeScalar(maxiter=10)),
487
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking(maxiter=10)),
486
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ScipyMinimizeScalar(maxiter=10)),
487
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.AdaptiveBacktracking(maxiter=10)),
488
488
  needs_closure=True,
489
489
  func='booth', steps=50, loss=1e-2, merge_invariant=True,
490
490
  sphere_steps=2, sphere_loss=0,
@@ -492,8 +492,8 @@ ScipyMinimizeScalar = Run(
492
492
 
493
493
  # ------------------------- line_search/strong_wolfe ------------------------- #
494
494
  StrongWolfe = Run(
495
- func_opt=lambda p: tz.Modular(p, tz.m.StrongWolfe()),
496
- sphere_opt=lambda p: tz.Modular(p, tz.m.StrongWolfe()),
495
+ func_opt=lambda p: tz.Optimizer(p, tz.m.StrongWolfe()),
496
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.StrongWolfe()),
497
497
  needs_closure=True,
498
498
  func='booth', steps=50, loss=0, merge_invariant=True,
499
499
  sphere_steps=2, sphere_loss=0,
@@ -501,44 +501,44 @@ StrongWolfe = Run(
501
501
 
502
502
  # ----------------------------------- lr/lr ---------------------------------- #
503
503
  LR = Run(
504
- func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
505
- sphere_opt=lambda p: tz.Modular(p, tz.m.LR(0.5)),
504
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LR(0.1)),
505
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LR(0.5)),
506
506
  needs_closure=False,
507
507
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
508
508
  sphere_steps=10, sphere_loss=0,
509
509
  )
510
510
  StepSize = Run(
511
- func_opt=lambda p: tz.Modular(p, tz.m.StepSize(0.1)),
512
- sphere_opt=lambda p: tz.Modular(p, tz.m.StepSize(0.5)),
511
+ func_opt=lambda p: tz.Optimizer(p, tz.m.StepSize(0.1)),
512
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.StepSize(0.5)),
513
513
  needs_closure=False,
514
514
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
515
515
  sphere_steps=10, sphere_loss=0,
516
516
  )
517
517
  Warmup = Run(
518
- func_opt=lambda p: tz.Modular(p, tz.m.Warmup(steps=50, end_lr=0.1)),
519
- sphere_opt=lambda p: tz.Modular(p, tz.m.Warmup(steps=10)),
518
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Warmup(steps=50, end_lr=0.1)),
519
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Warmup(steps=10)),
520
520
  needs_closure=False,
521
521
  func='booth', steps=50, loss=0.003, merge_invariant=True,
522
522
  sphere_steps=10, sphere_loss=0.05,
523
523
  )
524
524
  # ------------------------------- lr/step_size ------------------------------- #
525
525
  PolyakStepSize = Run(
526
- func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
527
- sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
526
+ func_opt=lambda p: tz.Optimizer(p, tz.m.PolyakStepSize()),
527
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.PolyakStepSize()),
528
528
  needs_closure=True,
529
529
  func='booth', steps=50, loss=1e-7, merge_invariant=True,
530
530
  sphere_steps=10, sphere_loss=0.002,
531
531
  )
532
532
  RandomStepSize = Run(
533
- func_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, seed=0)),
534
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, seed=0)),
533
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomStepSize(0,0.1, seed=0)),
534
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomStepSize(0,0.1, seed=0)),
535
535
  needs_closure=False,
536
536
  func='booth', steps=50, loss=0.0005, merge_invariant=True,
537
537
  sphere_steps=10, sphere_loss=100,
538
538
  )
539
539
  RandomStepSize_parameterwise = Run(
540
- func_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
541
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
540
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
541
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
542
542
  needs_closure=False,
543
543
  func='booth', steps=50, loss=0.0005, merge_invariant=False,
544
544
  sphere_steps=10, sphere_loss=100,
@@ -546,22 +546,22 @@ RandomStepSize_parameterwise = Run(
546
546
 
547
547
  # ---------------------------- momentum/averaging ---------------------------- #
548
548
  Averaging = Run(
549
- func_opt=lambda p: tz.Modular(p, tz.m.Averaging(10), tz.m.LR(0.02)),
550
- sphere_opt=lambda p: tz.Modular(p, tz.m.Averaging(10), tz.m.LR(0.2)),
549
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Averaging(10), tz.m.LR(0.02)),
550
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Averaging(10), tz.m.LR(0.2)),
551
551
  needs_closure=False,
552
552
  func='booth', steps=50, loss=0.5, merge_invariant=True,
553
553
  sphere_steps=10, sphere_loss=0.05,
554
554
  )
555
555
  WeightedAveraging = Run(
556
- func_opt=lambda p: tz.Modular(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.05)),
557
- sphere_opt=lambda p: tz.Modular(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.5)),
556
+ func_opt=lambda p: tz.Optimizer(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.05)),
557
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.5)),
558
558
  needs_closure=False,
559
559
  func='booth', steps=50, loss=1, merge_invariant=True,
560
560
  sphere_steps=10, sphere_loss=2,
561
561
  )
562
562
  MedianAveraging = Run(
563
- func_opt=lambda p: tz.Modular(p, tz.m.MedianAveraging(10), tz.m.LR(0.05)),
564
- sphere_opt=lambda p: tz.Modular(p, tz.m.MedianAveraging(10), tz.m.LR(0.5)),
563
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MedianAveraging(10), tz.m.LR(0.05)),
564
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MedianAveraging(10), tz.m.LR(0.5)),
565
565
  needs_closure=False,
566
566
  func='booth', steps=50, loss=0.005, merge_invariant=True,
567
567
  sphere_steps=10, sphere_loss=0,
@@ -569,36 +569,36 @@ MedianAveraging = Run(
569
569
 
570
570
  # ----------------------------- momentum/cautious ---------------------------- #
571
571
  Cautious = Run(
572
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
573
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
572
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
573
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
574
574
  needs_closure=False,
575
575
  func='booth', steps=50, loss=0.003, merge_invariant=True,
576
576
  sphere_steps=10, sphere_loss=2,
577
577
  )
578
578
  UpdateGradientSignConsistency = Run(
579
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
580
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
579
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
580
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
581
581
  needs_closure=False,
582
582
  func='booth', steps=50, loss=0.003, merge_invariant=True,
583
583
  sphere_steps=10, sphere_loss=2,
584
584
  )
585
585
  IntermoduleCautious = Run(
586
- func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
587
- sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
586
+ func_opt=lambda p: tz.Optimizer(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
587
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
588
588
  needs_closure=False,
589
589
  func='booth', steps=50, loss=1e-4, merge_invariant=True,
590
590
  sphere_steps=10, sphere_loss=0.1,
591
591
  )
592
592
  ScaleByGradCosineSimilarity = Run(
593
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.01)),
594
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.1)),
593
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.01)),
594
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.1)),
595
595
  needs_closure=False,
596
596
  func='booth', steps=50, loss=0.1, merge_invariant=True,
597
597
  sphere_steps=10, sphere_loss=0.1,
598
598
  )
599
599
  ScaleModulesByCosineSimilarity = Run(
600
- func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
601
- sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
600
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
601
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
602
602
  needs_closure=False,
603
603
  func='booth', steps=50, loss=0.005, merge_invariant=True,
604
604
  sphere_steps=10, sphere_loss=0.1,
@@ -606,66 +606,66 @@ ScaleModulesByCosineSimilarity = Run(
606
606
 
607
607
  # ------------------------- momentum/matrix_momentum ------------------------- #
608
608
  MatrixMomentum_forward = Run(
609
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_forward'),),
610
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward')),
609
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_forward'),),
610
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward')),
611
611
  needs_closure=True,
612
612
  func='booth', steps=50, loss=0.05, merge_invariant=True,
613
613
  sphere_steps=10, sphere_loss=0.01,
614
614
  )
615
615
  MatrixMomentum_forward = Run(
616
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_central')),
617
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central')),
616
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_central')),
617
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central')),
618
618
  needs_closure=True,
619
619
  func='booth', steps=50, loss=0.05, merge_invariant=True,
620
620
  sphere_steps=10, sphere_loss=0.01,
621
621
  )
622
622
  MatrixMomentum_forward = Run(
623
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
624
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
623
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
624
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
625
625
  needs_closure=True,
626
626
  func='booth', steps=50, loss=0.05, merge_invariant=True,
627
627
  sphere_steps=10, sphere_loss=0.01,
628
628
  )
629
629
 
630
630
  AdaptiveMatrixMomentum_forward = Run(
631
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True)),
632
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True)),
631
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True)),
632
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True)),
633
633
  needs_closure=True,
634
634
  func='booth', steps=50, loss=0.05, merge_invariant=True,
635
635
  sphere_steps=10, sphere_loss=0.05,
636
636
  )
637
637
  AdaptiveMatrixMomentum_central = Run(
638
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True)),
639
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True)),
638
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True)),
639
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True)),
640
640
  needs_closure=True,
641
641
  func='booth', steps=50, loss=0.05, merge_invariant=True,
642
642
  sphere_steps=10, sphere_loss=0.05,
643
643
  )
644
644
  AdaptiveMatrixMomentum_autograd = Run(
645
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
646
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
645
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
646
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
647
647
  needs_closure=True,
648
648
  func='booth', steps=50, loss=0.05, merge_invariant=True,
649
649
  sphere_steps=10, sphere_loss=0.05,
650
650
  )
651
651
 
652
652
  StochasticAdaptiveMatrixMomentum_forward = Run(
653
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
654
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
653
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
654
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
655
655
  needs_closure=True,
656
656
  func='booth', steps=50, loss=0.05, merge_invariant=True,
657
657
  sphere_steps=10, sphere_loss=0.05,
658
658
  )
659
659
  StochasticAdaptiveMatrixMomentum_central = Run(
660
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
661
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
660
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
661
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
662
662
  needs_closure=True,
663
663
  func='booth', steps=50, loss=0.05, merge_invariant=True,
664
664
  sphere_steps=10, sphere_loss=0.05,
665
665
  )
666
666
  StochasticAdaptiveMatrixMomentum_autograd = Run(
667
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
668
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
667
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
668
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
669
669
  needs_closure=True,
670
670
  func='booth', steps=50, loss=0.05, merge_invariant=True,
671
671
  sphere_steps=10, sphere_loss=0.05,
@@ -674,44 +674,44 @@ StochasticAdaptiveMatrixMomentum_autograd = Run(
674
674
  # EMA, momentum are covered by test_identical
675
675
  # --------------------------------- ops/misc --------------------------------- #
676
676
  Previous = Run(
677
- func_opt=lambda p: tz.Modular(p, tz.m.Previous(10), tz.m.LR(0.05)),
678
- sphere_opt=lambda p: tz.Modular(p, tz.m.Previous(3), tz.m.LR(0.5)),
677
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Previous(10), tz.m.LR(0.05)),
678
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Previous(3), tz.m.LR(0.5)),
679
679
  needs_closure=False,
680
680
  func='booth', steps=50, loss=15, merge_invariant=True,
681
681
  sphere_steps=10, sphere_loss=0,
682
682
  )
683
683
  GradSign = Run(
684
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.05)),
685
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.5)),
684
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.05)),
685
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.5)),
686
686
  needs_closure=False,
687
687
  func='booth', steps=50, loss=0.0002, merge_invariant=True,
688
688
  sphere_steps=10, sphere_loss=0.1,
689
689
  )
690
690
  UpdateSign = Run(
691
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.05)),
692
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.5)),
691
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.05)),
692
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.5)),
693
693
  needs_closure=False,
694
694
  func='booth', steps=50, loss=0.01, merge_invariant=True,
695
695
  sphere_steps=10, sphere_loss=0,
696
696
  )
697
697
  GradAccumulation = Run(
698
- func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
699
- sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
698
+ func_opt=lambda p: tz.Optimizer(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
699
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
700
700
  needs_closure=False,
701
701
  func='booth', steps=50, loss=25, merge_invariant=True,
702
702
  sphere_steps=20, sphere_loss=1e-11,
703
703
  )
704
704
  NegateOnLossIncrease = Run(
705
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
706
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
705
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
706
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
707
707
  needs_closure=True,
708
708
  func='booth', steps=50, loss=0.1, merge_invariant=True,
709
709
  sphere_steps=20, sphere_loss=0.001,
710
710
  )
711
711
  # -------------------------------- misc/switch ------------------------------- #
712
712
  Alternate = Run(
713
- func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
714
- sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
713
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
714
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
715
715
  needs_closure=False,
716
716
  func='booth', steps=50, loss=1, merge_invariant=True,
717
717
  sphere_steps=20, sphere_loss=20,
@@ -719,16 +719,16 @@ Alternate = Run(
719
719
 
720
720
  # ------------------------------ optimizers/adam ----------------------------- #
721
721
  Adam = Run(
722
- func_opt=lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.5)),
723
- sphere_opt=lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.2)),
722
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Adam(), tz.m.LR(0.5)),
723
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Adam(), tz.m.LR(0.2)),
724
724
  needs_closure=False,
725
725
  func='rosen', steps=50, loss=4, merge_invariant=True,
726
726
  sphere_steps=20, sphere_loss=4,
727
727
  )
728
728
  # ------------------------------ optimizers/soap ----------------------------- #
729
729
  SOAP = Run(
730
- func_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(0.4)),
731
- sphere_opt=lambda p: tz.Modular(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
730
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(), tz.m.LR(0.4)),
731
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
732
732
  needs_closure=False,
733
733
  # merge and unmerge lrs are very different so need to test convergence separately somewhere
734
734
  func='rosen', steps=50, loss=4, merge_invariant=False,
@@ -736,16 +736,16 @@ SOAP = Run(
736
736
  )
737
737
  # ------------------------------ optimizers/lion ----------------------------- #
738
738
  Lion = Run(
739
- func_opt=lambda p: tz.Modular(p, tz.m.Lion(), tz.m.LR(1)),
740
- sphere_opt=lambda p: tz.Modular(p, tz.m.Lion(), tz.m.LR(0.1)),
739
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Lion(), tz.m.LR(1)),
740
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Lion(), tz.m.LR(0.1)),
741
741
  needs_closure=False,
742
742
  func='booth', steps=50, loss=0, merge_invariant=True,
743
743
  sphere_steps=20, sphere_loss=25,
744
744
  )
745
745
  # ---------------------------- optimizers/shampoo ---------------------------- #
746
746
  Shampoo = Run(
747
- func_opt=lambda p: tz.Modular(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
748
- sphere_opt=lambda p: tz.Modular(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
747
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
748
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
749
749
  needs_closure=False,
750
750
  # merge and unmerge lrs are very different so need to test convergence separately somewhere
751
751
  func='booth', steps=50, loss=0.02, merge_invariant=False,
@@ -754,32 +754,33 @@ Shampoo = Run(
754
754
 
755
755
  # ------------------------- quasi_newton/quasi_newton ------------------------ #
756
756
  BFGS = Run(
757
- func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
758
- sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
757
+ func_opt=lambda p: tz.Optimizer(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
758
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
759
759
  needs_closure=True,
760
760
  func='rosen', steps=50, loss=1e-10, merge_invariant=True,
761
761
  sphere_steps=10, sphere_loss=1e-10,
762
762
  )
763
763
  SR1 = Run(
764
- func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_restart=True, scale_first=True), tz.m.StrongWolfe(fallback=False)),
765
- sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
764
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SR1(ptol_restart=True), tz.m.StrongWolfe(c2=0.1)),
765
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(c2=0.1)),
766
766
  needs_closure=True,
767
767
  func='rosen', steps=50, loss=1e-12, merge_invariant=True,
768
768
  # this reaches 1e-13 on github so don't change to 0
769
769
  sphere_steps=10, sphere_loss=0,
770
770
  )
771
771
  SSVM = Run(
772
- func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
773
- sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
772
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SSVM(1), tz.m.StrongWolfe(fallback=True)),
773
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SSVM(1), tz.m.StrongWolfe(fallback=True)),
774
774
  needs_closure=True,
775
+ # this reaches 0.12 on github so don't change to 0.002
775
776
  func='rosen', steps=50, loss=0.2, merge_invariant=True,
776
777
  sphere_steps=10, sphere_loss=0,
777
778
  )
778
779
 
779
780
  # ---------------------------- quasi_newton/lbfgs ---------------------------- #
780
781
  LBFGS = Run(
781
- func_opt=lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
782
- sphere_opt=lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
782
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
783
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
783
784
  needs_closure=True,
784
785
  func='rosen', steps=50, loss=0, merge_invariant=True,
785
786
  sphere_steps=10, sphere_loss=0,
@@ -787,8 +788,8 @@ LBFGS = Run(
787
788
 
788
789
  # ----------------------------- quasi_newton/lsr1 ---------------------------- #
789
790
  LSR1 = Run(
790
- func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
791
- sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
791
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
792
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
792
793
  needs_closure=True,
793
794
  func='rosen', steps=50, loss=0, merge_invariant=True,
794
795
  sphere_steps=10, sphere_loss=0,
@@ -796,8 +797,8 @@ LSR1 = Run(
796
797
 
797
798
  # # ---------------------------- quasi_newton/olbfgs --------------------------- #
798
799
  # OnlineLBFGS = Run(
799
- # func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
800
- # sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
800
+ # func_opt=lambda p: tz.Optimizer(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
801
+ # sphere_opt=lambda p: tz.Optimizer(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
801
802
  # needs_closure=True,
802
803
  # func='rosen', steps=50, loss=0, merge_invariant=True,
803
804
  # sphere_steps=10, sphere_loss=0,
@@ -805,8 +806,8 @@ LSR1 = Run(
805
806
 
806
807
  # ---------------------------- second_order/newton --------------------------- #
807
808
  Newton = Run(
808
- func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
809
- sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
809
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
810
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
810
811
  needs_closure=True,
811
812
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
812
813
  sphere_steps=2, sphere_loss=1e-9,
@@ -814,8 +815,8 @@ Newton = Run(
814
815
 
815
816
  # --------------------------- second_order/newton_cg -------------------------- #
816
817
  NewtonCG = Run(
817
- func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
818
- sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
818
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
819
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
819
820
  needs_closure=True,
820
821
  func='rosen', steps=20, loss=1e-10, merge_invariant=True,
821
822
  sphere_steps=2, sphere_loss=3e-4,
@@ -823,8 +824,8 @@ NewtonCG = Run(
823
824
 
824
825
  # ---------------------------- smoothing/gaussian ---------------------------- #
825
826
  GaussianHomotopy = Run(
826
- func_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
827
- sphere_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
827
+ func_opt=lambda p: tz.Optimizer(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
828
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
828
829
  needs_closure=True,
829
830
  func='booth', steps=20, loss=0.01, merge_invariant=True,
830
831
  sphere_steps=10, sphere_loss=1,
@@ -832,16 +833,16 @@ GaussianHomotopy = Run(
832
833
 
833
834
  # ---------------------------- smoothing/laplacian --------------------------- #
834
835
  LaplacianSmoothing = Run(
835
- func_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.1)),
836
- sphere_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.5)),
836
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.1)),
837
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.5)),
837
838
  needs_closure=False,
838
839
  func='booth', steps=50, loss=0.4, merge_invariant=False,
839
840
  sphere_steps=10, sphere_loss=3,
840
841
  )
841
842
 
842
843
  LaplacianSmoothing_global = Run(
843
- func_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.1)),
844
- sphere_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.5)),
844
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.1)),
845
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.5)),
845
846
  needs_closure=False,
846
847
  func='booth', steps=50, loss=0.4, merge_invariant=True,
847
848
  sphere_steps=10, sphere_loss=3,
@@ -849,8 +850,8 @@ LaplacianSmoothing_global = Run(
849
850
 
850
851
  # -------------------------- wrappers/optim_wrapper -------------------------- #
851
852
  Wrap = Run(
852
- func_opt=lambda p: tz.Modular(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.5)),
853
- sphere_opt=lambda p: tz.Modular(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.2)),
853
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.5)),
854
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.2)),
854
855
  needs_closure=False,
855
856
  func='rosen', steps=50, loss=4, merge_invariant=True,
856
857
  sphere_steps=20, sphere_loss=4,
@@ -858,15 +859,15 @@ Wrap = Run(
858
859
 
859
860
  # --------------------------- second_order/nystrom --------------------------- #
860
861
  NystromSketchAndSolve = Run(
861
- func_opt=lambda p: tz.Modular(p, tz.m.NystromSketchAndSolve(2, seed=0), tz.m.StrongWolfe()),
862
- sphere_opt=lambda p: tz.Modular(p, tz.m.NystromSketchAndSolve(10, seed=0), tz.m.StrongWolfe()),
862
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NystromSketchAndSolve(2, seed=0), tz.m.StrongWolfe()),
863
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NystromSketchAndSolve(10, seed=0), tz.m.StrongWolfe()),
863
864
  needs_closure=True,
864
865
  func='booth', steps=3, loss=1e-6, merge_invariant=True,
865
866
  sphere_steps=10, sphere_loss=1e-12,
866
867
  )
867
868
  NystromPCG = Run(
868
- func_opt=lambda p: tz.Modular(p, tz.m.NystromPCG(2, seed=0), tz.m.StrongWolfe()),
869
- sphere_opt=lambda p: tz.Modular(p, tz.m.NystromPCG(10, seed=0), tz.m.StrongWolfe()),
869
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NystromPCG(2, seed=0), tz.m.StrongWolfe()),
870
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NystromPCG(10, seed=0), tz.m.StrongWolfe()),
870
871
  needs_closure=True,
871
872
  func='ill', steps=2, loss=1e-5, merge_invariant=True,
872
873
  sphere_steps=2, sphere_loss=1e-9,
@@ -874,8 +875,8 @@ NystromPCG = Run(
874
875
 
875
876
  # ---------------------------- optimizers/sophia_h --------------------------- #
876
877
  SophiaH = Run(
877
- func_opt=lambda p: tz.Modular(p, tz.m.SophiaH(seed=0), tz.m.LR(0.1)),
878
- sphere_opt=lambda p: tz.Modular(p, tz.m.SophiaH(seed=0), tz.m.LR(0.3)),
878
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SophiaH(seed=0), tz.m.LR(0.1)),
879
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SophiaH(seed=0), tz.m.LR(0.3)),
879
880
  needs_closure=True,
880
881
  func='ill', steps=50, loss=0.02, merge_invariant=True,
881
882
  sphere_steps=10, sphere_loss=40,
@@ -883,17 +884,17 @@ SophiaH = Run(
883
884
 
884
885
  # -------------------------- higher_order ------------------------- #
885
886
  HigherOrderNewton = Run(
886
- func_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(trust_method=None)),
887
- sphere_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(2, trust_method=None)),
887
+ func_opt=lambda p: tz.Optimizer(p, tz.m.experimental.HigherOrderNewton(trust_method=None)),
888
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.experimental.HigherOrderNewton(2, trust_method=None)),
888
889
  needs_closure=True,
889
890
  func='rosen', steps=1, loss=2e-10, merge_invariant=True,
890
891
  sphere_steps=1, sphere_loss=1e-10,
891
892
  )
892
893
 
893
894
  # ---------------------------- optimizers/ladagrad --------------------------- #
894
- LMAdagrad = Run(
895
- func_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(4)),
896
- sphere_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(5)),
895
+ GGT = Run(
896
+ func_opt=lambda p: tz.Optimizer(p, tz.m.GGT(), tz.m.LR(4)),
897
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.GGT(), tz.m.LR(5)),
897
898
  needs_closure=False,
898
899
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
899
900
  sphere_steps=20, sphere_loss=1e-9,
@@ -901,8 +902,8 @@ LMAdagrad = Run(
901
902
 
902
903
  # ------------------------------ optimizers/adan ----------------------------- #
903
904
  Adan = Run(
904
- func_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(1)),
905
- sphere_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(0.1)),
905
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Adan(), tz.m.LR(1)),
906
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Adan(), tz.m.LR(0.1)),
906
907
  needs_closure=False,
907
908
  func='booth', steps=50, loss=60, merge_invariant=True,
908
909
  sphere_steps=20, sphere_loss=60,
@@ -913,8 +914,8 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
913
914
  for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
914
915
  # but also test 10 to make sure it doesn't explode after converging
915
916
  Run(
916
- func_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
917
- sphere_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
917
+ func_opt=lambda p: tz.Optimizer(p, CG(), tz.m.StrongWolfe(c2=0.1)),
918
+ sphere_opt=lambda p: tz.Optimizer(p, CG(), tz.m.StrongWolfe(c2=0.1)),
918
919
  needs_closure=True,
919
920
  func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=True,
920
921
  sphere_steps=sphere_steps_, sphere_loss=0,
@@ -947,8 +948,8 @@ for QN in (
947
948
  tz.m.SSVM,
948
949
  ):
949
950
  Run(
950
- func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
951
- sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
951
+ func_opt=lambda p: tz.Optimizer(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
952
+ sphere_opt=lambda p: tz.Optimizer(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
952
953
  needs_closure=True,
953
954
  func='lstsq', steps=50, loss=1e-10, merge_invariant=True,
954
955
  sphere_steps=10, sphere_loss=1e-20,