torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187):
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_opts.py CHANGED
@@ -4,15 +4,23 @@ Sanity tests to make sure everything works.
4
4
  This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
5
5
  don't error or become unhinged with different parameter shapes.
6
6
  """
7
+ import random
7
8
  from collections.abc import Callable
8
9
  from functools import partial
9
10
 
11
+ import numpy as np
10
12
  import pytest
11
13
  import torch
14
+
12
15
  import torchzero as tz
13
16
 
14
17
  PRINT = False # set to true in nbs
15
18
 
19
+ # seed
20
+ torch.manual_seed(0)
21
+ np.random.seed(0)
22
+ random.seed(0)
23
+
16
24
  def _booth(x, y):
17
25
  return (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2
18
26
 
@@ -46,12 +54,12 @@ class _TestModel(torch.nn.Module):
46
54
  def forward(self):
47
55
  return torch.sum(torch.stack([p.pow(2).sum() for p in self.params]))
48
56
 
49
- def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, steps: int, clear: bool):
57
+ def _run_objective(opt: tz.Optimizer, objective: Callable, use_closure: bool, steps: int, clear: bool):
50
58
  """generic function to run opt on objective and return lowest recorded loss"""
51
59
  losses = []
52
60
  for i in range(steps):
53
61
  if clear and i == steps//2:
54
- for m in opt.unrolled_modules: m.reset() # clear on middle step to see if there are any issues with it
62
+ for m in opt.flat_modules: m.reset() # clear on middle step to see if there are any issues with it
55
63
 
56
64
  if use_closure:
57
65
  def closure(backward=True):
@@ -146,8 +154,8 @@ class Run:
146
154
  Holds arguments for a test.
147
155
 
148
156
  Args:
149
- func_opt (Callable): opt for test function e.g. :code:`lambda p: tz.Modular(p, tz.m.Adam())`
150
- sphere_opt (Callable): opt for sphere e.g. :code:`lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.1))`
157
+ func_opt (Callable): opt for test function e.g. :code:`lambda p: tz.Optimizer(p, tz.m.Adam())`
158
+ sphere_opt (Callable): opt for sphere e.g. :code:`lambda p: tz.Optimizer(p, tz.m.Adam(), tz.m.LR(0.1))`
151
159
  needs_closure (bool): set to True if opt_fn requires closure
152
160
  func (str): what test function to use ("booth", "rosen", "ill")
153
161
  steps (int): number of steps to run test function for.
@@ -168,50 +176,50 @@ class Run:
168
176
  # ---------------------------------------------------------------------------- #
169
177
  # ----------------------------- clipping/clipping ---------------------------- #
170
178
  ClipValue = Run(
171
- func_opt=lambda p: tz.Modular(p, tz.m.ClipValue(1), tz.m.LR(1)),
172
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValue(1), tz.m.LR(1)),
179
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipValue(1), tz.m.LR(1)),
180
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipValue(1), tz.m.LR(1)),
173
181
  needs_closure=False,
174
182
  func='booth', steps=50, loss=0, merge_invariant=True,
175
183
  sphere_steps=10, sphere_loss=50,
176
184
  )
177
185
  ClipNorm = Run(
178
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1), tz.m.LR(1)),
179
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1), tz.m.LR(0.5)),
186
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNorm(1), tz.m.LR(1)),
187
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNorm(1), tz.m.LR(0.5)),
180
188
  needs_closure=False,
181
189
  func='booth', steps=50, loss=2, merge_invariant=False,
182
190
  sphere_steps=10, sphere_loss=0,
183
191
  )
184
192
  ClipNorm_global = Run(
185
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(1)),
186
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(3)),
193
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(1)),
194
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(3)),
187
195
  needs_closure=False,
188
196
  func='booth', steps=50, loss=2, merge_invariant=True,
189
197
  sphere_steps=10, sphere_loss=2,
190
198
  )
191
199
  Normalize = Run(
192
- func_opt=lambda p: tz.Modular(p, tz.m.Normalize(1), tz.m.LR(1)),
193
- sphere_opt=lambda p: tz.Modular(p, tz.m.Normalize(1), tz.m.LR(0.5)),
200
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Normalize(1), tz.m.LR(1)),
201
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Normalize(1), tz.m.LR(0.5)),
194
202
  needs_closure=False,
195
203
  func='booth', steps=50, loss=2, merge_invariant=False,
196
204
  sphere_steps=10, sphere_loss=15,
197
205
  )
198
206
  Normalize_global = Run(
199
- func_opt=lambda p: tz.Modular(p, tz.m.Normalize(1, dim='global'), tz.m.LR(1)),
200
- sphere_opt=lambda p: tz.Modular(p, tz.m.Normalize(1, dim='global'), tz.m.LR(4)),
207
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Normalize(1, dim='global'), tz.m.LR(1)),
208
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Normalize(1, dim='global'), tz.m.LR(4)),
201
209
  needs_closure=False,
202
210
  func='booth', steps=50, loss=2, merge_invariant=True,
203
211
  sphere_steps=10, sphere_loss=2,
204
212
  )
205
213
  Centralize = Run(
206
- func_opt=lambda p: tz.Modular(p, tz.m.Centralize(min_size=3), tz.m.LR(0.1)),
207
- sphere_opt=lambda p: tz.Modular(p, tz.m.Centralize(), tz.m.LR(0.1)),
214
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Centralize(min_size=3), tz.m.LR(0.1)),
215
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Centralize(), tz.m.LR(0.1)),
208
216
  needs_closure=False,
209
217
  func='booth', steps=50, loss=1e-6, merge_invariant=False,
210
218
  sphere_steps=10, sphere_loss=10,
211
219
  )
212
220
  Centralize_global = Run(
213
- func_opt=lambda p: tz.Modular(p, tz.m.Centralize(min_size=3, dim='global'), tz.m.LR(0.1)),
214
- sphere_opt=lambda p: tz.Modular(p, tz.m.Centralize(dim='global'), tz.m.LR(0.1)),
221
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Centralize(min_size=3, dim='global'), tz.m.LR(0.1)),
222
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Centralize(dim='global'), tz.m.LR(0.1)),
215
223
  needs_closure=False,
216
224
  func='booth', steps=1, loss=1000, merge_invariant=True,
217
225
  sphere_steps=10, sphere_loss=10,
@@ -219,72 +227,72 @@ Centralize_global = Run(
219
227
 
220
228
  # --------------------------- clipping/ema_clipping -------------------------- #
221
229
  ClipNormByEMA = Run(
222
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(), tz.m.LR(0.1)),
223
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(), tz.m.LR(5)),
230
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormByEMA(), tz.m.LR(0.1)),
231
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormByEMA(), tz.m.LR(5)),
224
232
  needs_closure=False,
225
233
  func='booth', steps=50, loss=1e-5, merge_invariant=False,
226
234
  sphere_steps=10, sphere_loss=0.1,
227
235
  )
228
236
  ClipNormByEMA_global = Run(
229
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(0.1)),
230
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(5)),
237
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(0.1)),
238
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(5)),
231
239
  needs_closure=False,
232
240
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
233
241
  sphere_steps=10, sphere_loss=0.1,
234
242
  )
235
243
  NormalizeByEMA = Run(
236
- func_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(), tz.m.LR(0.05)),
237
- sphere_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(), tz.m.LR(5)),
244
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NormalizeByEMA(), tz.m.LR(0.05)),
245
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NormalizeByEMA(), tz.m.LR(5)),
238
246
  needs_closure=False,
239
247
  func='booth', steps=50, loss=1, merge_invariant=False,
240
248
  sphere_steps=10, sphere_loss=0.1,
241
249
  )
242
250
  NormalizeByEMA_global = Run(
243
- func_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(0.05)),
244
- sphere_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(5)),
251
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(0.05)),
252
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(5)),
245
253
  needs_closure=False,
246
254
  func='booth', steps=50, loss=1, merge_invariant=True,
247
255
  sphere_steps=10, sphere_loss=0.1,
248
256
  )
249
257
  ClipValueByEMA = Run(
250
- func_opt=lambda p: tz.Modular(p, tz.m.ClipValueByEMA(), tz.m.LR(0.1)),
251
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueByEMA(), tz.m.LR(4)),
258
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueByEMA(), tz.m.LR(0.1)),
259
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueByEMA(), tz.m.LR(4)),
252
260
  needs_closure=False,
253
261
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
254
262
  sphere_steps=10, sphere_loss=0.03,
255
263
  )
256
264
  # ------------------------- clipping/growth_clipping ------------------------- #
257
265
  ClipValueGrowth = Run(
258
- func_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
259
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
266
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
267
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
260
268
  needs_closure=False,
261
269
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
262
270
  sphere_steps=10, sphere_loss=100,
263
271
  )
264
272
  ClipValueGrowth_additive = Run(
265
- func_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
266
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
273
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
274
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
267
275
  needs_closure=False,
268
276
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
269
277
  sphere_steps=10, sphere_loss=10,
270
278
  )
271
279
  ClipNormGrowth = Run(
272
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
273
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
280
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
281
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
274
282
  needs_closure=False,
275
283
  func='booth', steps=50, loss=1e-6, merge_invariant=False,
276
284
  sphere_steps=10, sphere_loss=10,
277
285
  )
278
286
  ClipNormGrowth_additive = Run(
279
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
280
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
287
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
288
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
281
289
  needs_closure=False,
282
290
  func='booth', steps=50, loss=1e-6, merge_invariant=False,
283
291
  sphere_steps=10, sphere_loss=10,
284
292
  )
285
293
  ClipNormGrowth_global = Run(
286
- func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(parameterwise=False), tz.m.LR(0.1)),
287
- sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(parameterwise=False), tz.m.LR(0.1)),
294
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
295
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
288
296
  needs_closure=False,
289
297
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
290
298
  sphere_steps=10, sphere_loss=10,
@@ -292,43 +300,43 @@ ClipNormGrowth_global = Run(
292
300
 
293
301
  # -------------------------- grad_approximation/fdm -------------------------- #
294
302
  FDM_central2 = Run(
295
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
296
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
303
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
304
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(), tz.m.LR(0.1)),
297
305
  needs_closure=True,
298
306
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
299
307
  sphere_steps=2, sphere_loss=340,
300
308
  )
301
309
  FDM_forward2 = Run(
302
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
303
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
310
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
311
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
304
312
  needs_closure=True,
305
313
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
306
314
  sphere_steps=2, sphere_loss=340,
307
315
  )
308
316
  FDM_backward2 = Run(
309
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
310
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
317
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
318
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
311
319
  needs_closure=True,
312
320
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
313
321
  sphere_steps=2, sphere_loss=340,
314
322
  )
315
323
  FDM_forward3 = Run(
316
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
317
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
324
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
325
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
318
326
  needs_closure=True,
319
327
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
320
328
  sphere_steps=2, sphere_loss=340,
321
329
  )
322
330
  FDM_backward3 = Run(
323
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
324
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
331
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
332
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
325
333
  needs_closure=True,
326
334
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
327
335
  sphere_steps=2, sphere_loss=340,
328
336
  )
329
337
  FDM_central4 = Run(
330
- func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
331
- sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
338
+ func_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
339
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
332
340
  needs_closure=True,
333
341
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
334
342
  sphere_steps=2, sphere_loss=340,
@@ -336,147 +344,147 @@ FDM_central4 = Run(
336
344
 
337
345
  # -------------------------- grad_approximation/rfdm ------------------------- #
338
346
  RandomizedFDM_central2 = Run(
339
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.01)),
340
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.001)),
347
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.01)),
348
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.001)),
341
349
  needs_closure=True,
342
350
  func='booth', steps=50, loss=10, merge_invariant=True,
343
- sphere_steps=100, sphere_loss=450,
351
+ sphere_steps=200, sphere_loss=420,
344
352
  )
345
353
  RandomizedFDM_forward2 = Run(
346
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.01)),
347
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.001)),
354
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.01)),
355
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.001)),
348
356
  needs_closure=True,
349
357
  func='booth', steps=50, loss=10, merge_invariant=True,
350
- sphere_steps=100, sphere_loss=450,
358
+ sphere_steps=200, sphere_loss=420,
351
359
  )
352
360
  RandomizedFDM_backward2 = Run(
353
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.01)),
354
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.001)),
361
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.01)),
362
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.001)),
355
363
  needs_closure=True,
356
364
  func='booth', steps=50, loss=10, merge_invariant=True,
357
- sphere_steps=100, sphere_loss=450,
365
+ sphere_steps=200, sphere_loss=420,
358
366
  )
359
367
  RandomizedFDM_forward3 = Run(
360
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.01)),
361
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.001)),
368
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.01)),
369
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.001)),
362
370
  needs_closure=True,
363
371
  func='booth', steps=50, loss=10, merge_invariant=True,
364
- sphere_steps=100, sphere_loss=450,
372
+ sphere_steps=200, sphere_loss=420,
365
373
  )
366
374
  RandomizedFDM_backward3 = Run(
367
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.01)),
368
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.001)),
375
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.01)),
376
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.001)),
369
377
  needs_closure=True,
370
378
  func='booth', steps=50, loss=10, merge_invariant=True,
371
- sphere_steps=100, sphere_loss=450,
379
+ sphere_steps=200, sphere_loss=420,
372
380
  )
373
381
  RandomizedFDM_central4 = Run(
374
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.01)),
375
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.001)),
382
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.01)),
383
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.001)),
376
384
  needs_closure=True,
377
385
  func='booth', steps=50, loss=10, merge_invariant=True,
378
- sphere_steps=100, sphere_loss=450,
386
+ sphere_steps=200, sphere_loss=420,
379
387
  )
380
388
  RandomizedFDM_forward4 = Run(
381
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
382
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
389
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
390
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
383
391
  needs_closure=True,
384
392
  func='booth', steps=50, loss=10, merge_invariant=True,
385
- sphere_steps=100, sphere_loss=450,
393
+ sphere_steps=200, sphere_loss=420,
386
394
  )
387
395
  RandomizedFDM_forward5 = Run(
388
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
389
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
396
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
397
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
390
398
  needs_closure=True,
391
399
  func='booth', steps=50, loss=10, merge_invariant=True,
392
- sphere_steps=100, sphere_loss=450,
400
+ sphere_steps=200, sphere_loss=420,
393
401
  )
394
402
 
395
403
 
396
404
  RandomizedFDM_4samples = Run(
397
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
398
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.001)),
405
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
406
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.001)),
399
407
  needs_closure=True,
400
408
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
401
409
  sphere_steps=100, sphere_loss=400,
402
410
  )
403
411
  RandomizedFDM_4samples_no_pre_generate = Run(
404
- func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
405
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
412
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
413
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
406
414
  needs_closure=True,
407
415
  func='booth', steps=50, loss=1e-5, merge_invariant=True,
408
416
  sphere_steps=100, sphere_loss=400,
409
417
  )
410
418
  MeZO = Run(
411
- func_opt=lambda p: tz.Modular(p, tz.m.MeZO(), tz.m.LR(0.01)),
412
- sphere_opt=lambda p: tz.Modular(p, tz.m.MeZO(), tz.m.LR(0.001)),
419
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MeZO(), tz.m.LR(0.01)),
420
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MeZO(), tz.m.LR(0.001)),
413
421
  needs_closure=True,
414
422
  func='booth', steps=50, loss=5, merge_invariant=True,
415
423
  sphere_steps=100, sphere_loss=450,
416
424
  )
417
425
  MeZO_4samples = Run(
418
- func_opt=lambda p: tz.Modular(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.02)),
419
- sphere_opt=lambda p: tz.Modular(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.005)),
426
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.02)),
427
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.005)),
420
428
  needs_closure=True,
421
429
  func='booth', steps=50, loss=1, merge_invariant=True,
422
430
  sphere_steps=100, sphere_loss=250,
423
431
  )
424
432
  # -------------------- grad_approximation/forward_gradient ------------------- #
425
433
  ForwardGradient = Run(
426
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.01)),
427
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.001)),
434
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.01)),
435
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.001)),
428
436
  needs_closure=True,
429
437
  func='booth', steps=50, loss=40, merge_invariant=True,
430
- sphere_steps=100, sphere_loss=450,
438
+ sphere_steps=200, sphere_loss=450,
431
439
  )
432
440
  ForwardGradient_forward = Run(
433
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.01)),
434
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.001)),
441
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.01)),
442
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.001)),
435
443
  needs_closure=True,
436
444
  func='booth', steps=50, loss=40, merge_invariant=True,
437
- sphere_steps=100, sphere_loss=450,
445
+ sphere_steps=200, sphere_loss=450,
438
446
  )
439
447
  ForwardGradient_central = Run(
440
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.01)),
441
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.001)),
448
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.01)),
449
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.001)),
442
450
  needs_closure=True,
443
451
  func='booth', steps=50, loss=40, merge_invariant=True,
444
- sphere_steps=100, sphere_loss=450,
452
+ sphere_steps=200, sphere_loss=450,
445
453
  )
446
454
  ForwardGradient_4samples = Run(
447
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.1)),
448
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.001)),
455
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.1)),
456
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.001)),
449
457
  needs_closure=True,
450
458
  func='booth', steps=50, loss=0.1, merge_invariant=True,
451
- sphere_steps=100, sphere_loss=400,
459
+ sphere_steps=100, sphere_loss=420,
452
460
  )
453
461
  ForwardGradient_4samples_no_pre_generate = Run(
454
- func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.1)),
455
- sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.001)),
462
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.1)),
463
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.001)),
456
464
  needs_closure=True,
457
465
  func='booth', steps=50, loss=0.1, merge_invariant=True,
458
- sphere_steps=100, sphere_loss=400,
466
+ sphere_steps=100, sphere_loss=420,
459
467
  )
460
468
 
461
469
  # ------------------------- line_search/backtracking ------------------------- #
462
470
  Backtracking = Run(
463
- func_opt=lambda p: tz.Modular(p, tz.m.Backtracking()),
464
- sphere_opt=lambda p: tz.Modular(p, tz.m.Backtracking()),
471
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Backtracking()),
472
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Backtracking()),
465
473
  needs_closure=True,
466
474
  func='booth', steps=50, loss=0, merge_invariant=True,
467
475
  sphere_steps=2, sphere_loss=0,
468
476
  )
469
477
  AdaptiveBacktracking = Run(
470
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
471
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
478
+ func_opt=lambda p: tz.Optimizer(p, tz.m.AdaptiveBacktracking()),
479
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.AdaptiveBacktracking()),
472
480
  needs_closure=True,
473
481
  func='booth', steps=50, loss=1e-11, merge_invariant=True,
474
482
  sphere_steps=2, sphere_loss=1e-10,
475
483
  )
476
484
  # ----------------------------- line_search/scipy ---------------------------- #
477
485
  ScipyMinimizeScalar = Run(
478
- func_opt=lambda p: tz.Modular(p, tz.m.ScipyMinimizeScalar(maxiter=10)),
479
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking(maxiter=10)),
486
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ScipyMinimizeScalar(maxiter=10)),
487
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.AdaptiveBacktracking(maxiter=10)),
480
488
  needs_closure=True,
481
489
  func='booth', steps=50, loss=1e-2, merge_invariant=True,
482
490
  sphere_steps=2, sphere_loss=0,
@@ -484,8 +492,8 @@ ScipyMinimizeScalar = Run(
484
492
 
485
493
  # ------------------------- line_search/strong_wolfe ------------------------- #
486
494
  StrongWolfe = Run(
487
- func_opt=lambda p: tz.Modular(p, tz.m.StrongWolfe()),
488
- sphere_opt=lambda p: tz.Modular(p, tz.m.StrongWolfe()),
495
+ func_opt=lambda p: tz.Optimizer(p, tz.m.StrongWolfe()),
496
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.StrongWolfe()),
489
497
  needs_closure=True,
490
498
  func='booth', steps=50, loss=0, merge_invariant=True,
491
499
  sphere_steps=2, sphere_loss=0,
@@ -493,44 +501,44 @@ StrongWolfe = Run(
493
501
 
494
502
  # ----------------------------------- lr/lr ---------------------------------- #
495
503
  LR = Run(
496
- func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
497
- sphere_opt=lambda p: tz.Modular(p, tz.m.LR(0.5)),
504
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LR(0.1)),
505
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LR(0.5)),
498
506
  needs_closure=False,
499
507
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
500
508
  sphere_steps=10, sphere_loss=0,
501
509
  )
502
510
  StepSize = Run(
503
- func_opt=lambda p: tz.Modular(p, tz.m.StepSize(0.1)),
504
- sphere_opt=lambda p: tz.Modular(p, tz.m.StepSize(0.5)),
511
+ func_opt=lambda p: tz.Optimizer(p, tz.m.StepSize(0.1)),
512
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.StepSize(0.5)),
505
513
  needs_closure=False,
506
514
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
507
515
  sphere_steps=10, sphere_loss=0,
508
516
  )
509
517
  Warmup = Run(
510
- func_opt=lambda p: tz.Modular(p, tz.m.Warmup(steps=50, end_lr=0.1)),
511
- sphere_opt=lambda p: tz.Modular(p, tz.m.Warmup(steps=10)),
518
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Warmup(steps=50, end_lr=0.1)),
519
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Warmup(steps=10)),
512
520
  needs_closure=False,
513
521
  func='booth', steps=50, loss=0.003, merge_invariant=True,
514
522
  sphere_steps=10, sphere_loss=0.05,
515
523
  )
516
524
  # ------------------------------- lr/step_size ------------------------------- #
517
525
  PolyakStepSize = Run(
518
- func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
519
- sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
526
+ func_opt=lambda p: tz.Optimizer(p, tz.m.PolyakStepSize()),
527
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.PolyakStepSize()),
520
528
  needs_closure=True,
521
529
  func='booth', steps=50, loss=1e-7, merge_invariant=True,
522
530
  sphere_steps=10, sphere_loss=0.002,
523
531
  )
524
532
  RandomStepSize = Run(
525
- func_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, seed=0)),
526
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, seed=0)),
533
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomStepSize(0,0.1, seed=0)),
534
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomStepSize(0,0.1, seed=0)),
527
535
  needs_closure=False,
528
536
  func='booth', steps=50, loss=0.0005, merge_invariant=True,
529
537
  sphere_steps=10, sphere_loss=100,
530
538
  )
531
539
  RandomStepSize_parameterwise = Run(
532
- func_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
533
- sphere_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
540
+ func_opt=lambda p: tz.Optimizer(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
541
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
534
542
  needs_closure=False,
535
543
  func='booth', steps=50, loss=0.0005, merge_invariant=False,
536
544
  sphere_steps=10, sphere_loss=100,
@@ -538,22 +546,22 @@ RandomStepSize_parameterwise = Run(
538
546
 
539
547
  # ---------------------------- momentum/averaging ---------------------------- #
540
548
  Averaging = Run(
541
- func_opt=lambda p: tz.Modular(p, tz.m.Averaging(10), tz.m.LR(0.02)),
542
- sphere_opt=lambda p: tz.Modular(p, tz.m.Averaging(10), tz.m.LR(0.2)),
549
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Averaging(10), tz.m.LR(0.02)),
550
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Averaging(10), tz.m.LR(0.2)),
543
551
  needs_closure=False,
544
552
  func='booth', steps=50, loss=0.5, merge_invariant=True,
545
553
  sphere_steps=10, sphere_loss=0.05,
546
554
  )
547
555
  WeightedAveraging = Run(
548
- func_opt=lambda p: tz.Modular(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.05)),
549
- sphere_opt=lambda p: tz.Modular(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.5)),
556
+ func_opt=lambda p: tz.Optimizer(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.05)),
557
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.5)),
550
558
  needs_closure=False,
551
559
  func='booth', steps=50, loss=1, merge_invariant=True,
552
560
  sphere_steps=10, sphere_loss=2,
553
561
  )
554
562
  MedianAveraging = Run(
555
- func_opt=lambda p: tz.Modular(p, tz.m.MedianAveraging(10), tz.m.LR(0.05)),
556
- sphere_opt=lambda p: tz.Modular(p, tz.m.MedianAveraging(10), tz.m.LR(0.5)),
563
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MedianAveraging(10), tz.m.LR(0.05)),
564
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MedianAveraging(10), tz.m.LR(0.5)),
557
565
  needs_closure=False,
558
566
  func='booth', steps=50, loss=0.005, merge_invariant=True,
559
567
  sphere_steps=10, sphere_loss=0,
@@ -561,36 +569,36 @@ MedianAveraging = Run(
561
569
 
562
570
  # ----------------------------- momentum/cautious ---------------------------- #
563
571
  Cautious = Run(
564
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
565
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
572
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
573
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
566
574
  needs_closure=False,
567
575
  func='booth', steps=50, loss=0.003, merge_invariant=True,
568
576
  sphere_steps=10, sphere_loss=2,
569
577
  )
570
578
  UpdateGradientSignConsistency = Run(
571
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
572
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
579
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
580
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
573
581
  needs_closure=False,
574
582
  func='booth', steps=50, loss=0.003, merge_invariant=True,
575
583
  sphere_steps=10, sphere_loss=2,
576
584
  )
577
585
  IntermoduleCautious = Run(
578
- func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
579
- sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
586
+ func_opt=lambda p: tz.Optimizer(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
587
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
580
588
  needs_closure=False,
581
589
  func='booth', steps=50, loss=1e-4, merge_invariant=True,
582
590
  sphere_steps=10, sphere_loss=0.1,
583
591
  )
584
592
  ScaleByGradCosineSimilarity = Run(
585
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.01)),
586
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.1)),
593
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.01)),
594
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.1)),
587
595
  needs_closure=False,
588
596
  func='booth', steps=50, loss=0.1, merge_invariant=True,
589
597
  sphere_steps=10, sphere_loss=0.1,
590
598
  )
591
599
  ScaleModulesByCosineSimilarity = Run(
592
- func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
593
- sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
600
+ func_opt=lambda p: tz.Optimizer(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
601
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
594
602
  needs_closure=False,
595
603
  func='booth', steps=50, loss=0.005, merge_invariant=True,
596
604
  sphere_steps=10, sphere_loss=0.1,
@@ -598,66 +606,66 @@ ScaleModulesByCosineSimilarity = Run(
598
606
 
599
607
  # ------------------------- momentum/matrix_momentum ------------------------- #
600
608
  MatrixMomentum_forward = Run(
601
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='forward'),),
602
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward')),
609
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_forward'),),
610
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward')),
603
611
  needs_closure=True,
604
612
  func='booth', steps=50, loss=0.05, merge_invariant=True,
605
613
  sphere_steps=10, sphere_loss=0.01,
606
614
  )
607
615
  MatrixMomentum_forward = Run(
608
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='central')),
609
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central')),
616
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_central')),
617
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central')),
610
618
  needs_closure=True,
611
619
  func='booth', steps=50, loss=0.05, merge_invariant=True,
612
620
  sphere_steps=10, sphere_loss=0.01,
613
621
  )
614
622
  MatrixMomentum_forward = Run(
615
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
616
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
623
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
624
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
617
625
  needs_closure=True,
618
626
  func='booth', steps=50, loss=0.05, merge_invariant=True,
619
627
  sphere_steps=10, sphere_loss=0.01,
620
628
  )
621
629
 
622
630
  AdaptiveMatrixMomentum_forward = Run(
623
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True)),
624
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True)),
631
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True)),
632
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True)),
625
633
  needs_closure=True,
626
634
  func='booth', steps=50, loss=0.05, merge_invariant=True,
627
635
  sphere_steps=10, sphere_loss=0.05,
628
636
  )
629
637
  AdaptiveMatrixMomentum_central = Run(
630
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True)),
631
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True)),
638
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True)),
639
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True)),
632
640
  needs_closure=True,
633
641
  func='booth', steps=50, loss=0.05, merge_invariant=True,
634
642
  sphere_steps=10, sphere_loss=0.05,
635
643
  )
636
644
  AdaptiveMatrixMomentum_autograd = Run(
637
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
638
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
645
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
646
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
639
647
  needs_closure=True,
640
648
  func='booth', steps=50, loss=0.05, merge_invariant=True,
641
649
  sphere_steps=10, sphere_loss=0.05,
642
650
  )
643
651
 
644
652
  StochasticAdaptiveMatrixMomentum_forward = Run(
645
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True, adapt_freq=1)),
646
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True, adapt_freq=1)),
653
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
654
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
647
655
  needs_closure=True,
648
656
  func='booth', steps=50, loss=0.05, merge_invariant=True,
649
657
  sphere_steps=10, sphere_loss=0.05,
650
658
  )
651
659
  StochasticAdaptiveMatrixMomentum_central = Run(
652
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True, adapt_freq=1)),
653
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True, adapt_freq=1)),
660
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
661
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
654
662
  needs_closure=True,
655
663
  func='booth', steps=50, loss=0.05, merge_invariant=True,
656
664
  sphere_steps=10, sphere_loss=0.05,
657
665
  )
658
666
  StochasticAdaptiveMatrixMomentum_autograd = Run(
659
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
660
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
667
+ func_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
668
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
661
669
  needs_closure=True,
662
670
  func='booth', steps=50, loss=0.05, merge_invariant=True,
663
671
  sphere_steps=10, sphere_loss=0.05,
@@ -666,44 +674,44 @@ StochasticAdaptiveMatrixMomentum_autograd = Run(
666
674
  # EMA, momentum are covered by test_identical
667
675
  # --------------------------------- ops/misc --------------------------------- #
668
676
  Previous = Run(
669
- func_opt=lambda p: tz.Modular(p, tz.m.Previous(10), tz.m.LR(0.05)),
670
- sphere_opt=lambda p: tz.Modular(p, tz.m.Previous(3), tz.m.LR(0.5)),
677
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Previous(10), tz.m.LR(0.05)),
678
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Previous(3), tz.m.LR(0.5)),
671
679
  needs_closure=False,
672
680
  func='booth', steps=50, loss=15, merge_invariant=True,
673
681
  sphere_steps=10, sphere_loss=0,
674
682
  )
675
683
  GradSign = Run(
676
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.05)),
677
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.5)),
684
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.05)),
685
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.5)),
678
686
  needs_closure=False,
679
687
  func='booth', steps=50, loss=0.0002, merge_invariant=True,
680
688
  sphere_steps=10, sphere_loss=0.1,
681
689
  )
682
690
  UpdateSign = Run(
683
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.05)),
684
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.5)),
691
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.05)),
692
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.5)),
685
693
  needs_closure=False,
686
694
  func='booth', steps=50, loss=0.01, merge_invariant=True,
687
695
  sphere_steps=10, sphere_loss=0,
688
696
  )
689
697
  GradAccumulation = Run(
690
- func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
691
- sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
698
+ func_opt=lambda p: tz.Optimizer(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
699
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
692
700
  needs_closure=False,
693
701
  func='booth', steps=50, loss=25, merge_invariant=True,
694
702
  sphere_steps=20, sphere_loss=1e-11,
695
703
  )
696
704
  NegateOnLossIncrease = Run(
697
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
698
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
705
+ func_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
706
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
699
707
  needs_closure=True,
700
708
  func='booth', steps=50, loss=0.1, merge_invariant=True,
701
709
  sphere_steps=20, sphere_loss=0.001,
702
710
  )
703
711
  # -------------------------------- misc/switch ------------------------------- #
704
712
  Alternate = Run(
705
- func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
706
- sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
713
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
714
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
707
715
  needs_closure=False,
708
716
  func='booth', steps=50, loss=1, merge_invariant=True,
709
717
  sphere_steps=20, sphere_loss=20,
@@ -711,65 +719,68 @@ Alternate = Run(
711
719
 
712
720
  # ------------------------------ optimizers/adam ----------------------------- #
713
721
  Adam = Run(
714
- func_opt=lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.5)),
715
- sphere_opt=lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.2)),
722
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Adam(), tz.m.LR(0.5)),
723
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Adam(), tz.m.LR(0.2)),
716
724
  needs_closure=False,
717
725
  func='rosen', steps=50, loss=4, merge_invariant=True,
718
726
  sphere_steps=20, sphere_loss=4,
719
727
  )
720
728
  # ------------------------------ optimizers/soap ----------------------------- #
721
729
  SOAP = Run(
722
- func_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(0.4)),
723
- sphere_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(1)),
730
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(), tz.m.LR(0.4)),
731
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
724
732
  needs_closure=False,
733
+ # merge and unmerge lrs are very different so need to test convergence separately somewhere
725
734
  func='rosen', steps=50, loss=4, merge_invariant=False,
726
- sphere_steps=20, sphere_loss=25, # merge and unmerge lrs are very different so need to test convergence separately somewhere
735
+ sphere_steps=20, sphere_loss=25,
727
736
  )
728
737
  # ------------------------------ optimizers/lion ----------------------------- #
729
738
  Lion = Run(
730
- func_opt=lambda p: tz.Modular(p, tz.m.Lion(), tz.m.LR(1)),
731
- sphere_opt=lambda p: tz.Modular(p, tz.m.Lion(), tz.m.LR(0.1)),
739
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Lion(), tz.m.LR(1)),
740
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Lion(), tz.m.LR(0.1)),
732
741
  needs_closure=False,
733
742
  func='booth', steps=50, loss=0, merge_invariant=True,
734
743
  sphere_steps=20, sphere_loss=25,
735
744
  )
736
745
  # ---------------------------- optimizers/shampoo ---------------------------- #
737
746
  Shampoo = Run(
738
- func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
739
- sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
747
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
748
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
740
749
  needs_closure=False,
750
+ # merge and unmerge lrs are very different so need to test convergence separately somewhere
741
751
  func='booth', steps=50, loss=0.02, merge_invariant=False,
742
- sphere_steps=20, sphere_loss=1, # merge and unmerge lrs are very different so need to test convergence separately somewhere
752
+ sphere_steps=20, sphere_loss=1,
743
753
  )
744
754
 
745
755
  # ------------------------- quasi_newton/quasi_newton ------------------------ #
746
756
  BFGS = Run(
747
- func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
748
- sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
757
+ func_opt=lambda p: tz.Optimizer(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
758
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
749
759
  needs_closure=True,
750
760
  func='rosen', steps=50, loss=1e-10, merge_invariant=True,
751
761
  sphere_steps=10, sphere_loss=1e-10,
752
762
  )
753
763
  SR1 = Run(
754
- func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_restart=True, scale_first=True), tz.m.StrongWolfe(fallback=False)),
755
- sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
764
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SR1(ptol_restart=True), tz.m.StrongWolfe(c2=0.1)),
765
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(c2=0.1)),
756
766
  needs_closure=True,
757
767
  func='rosen', steps=50, loss=1e-12, merge_invariant=True,
758
768
  # this reaches 1e-13 on github so don't change to 0
759
769
  sphere_steps=10, sphere_loss=0,
760
770
  )
761
771
  SSVM = Run(
762
- func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
763
- sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
772
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SSVM(1), tz.m.StrongWolfe(fallback=True)),
773
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SSVM(1), tz.m.StrongWolfe(fallback=True)),
764
774
  needs_closure=True,
775
+ # this reaches 0.12 on github so don't change to 0.002
765
776
  func='rosen', steps=50, loss=0.2, merge_invariant=True,
766
777
  sphere_steps=10, sphere_loss=0,
767
778
  )
768
779
 
769
780
  # ---------------------------- quasi_newton/lbfgs ---------------------------- #
770
781
  LBFGS = Run(
771
- func_opt=lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
772
- sphere_opt=lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
782
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
783
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
773
784
  needs_closure=True,
774
785
  func='rosen', steps=50, loss=0, merge_invariant=True,
775
786
  sphere_steps=10, sphere_loss=0,
@@ -777,8 +788,8 @@ LBFGS = Run(
777
788
 
778
789
  # ----------------------------- quasi_newton/lsr1 ---------------------------- #
779
790
  LSR1 = Run(
780
- func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
781
- sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
791
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
792
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
782
793
  needs_closure=True,
783
794
  func='rosen', steps=50, loss=0, merge_invariant=True,
784
795
  sphere_steps=10, sphere_loss=0,
@@ -786,8 +797,8 @@ LSR1 = Run(
786
797
 
787
798
  # # ---------------------------- quasi_newton/olbfgs --------------------------- #
788
799
  # OnlineLBFGS = Run(
789
- # func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
790
- # sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
800
+ # func_opt=lambda p: tz.Optimizer(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
801
+ # sphere_opt=lambda p: tz.Optimizer(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
791
802
  # needs_closure=True,
792
803
  # func='rosen', steps=50, loss=0, merge_invariant=True,
793
804
  # sphere_steps=10, sphere_loss=0,
@@ -795,8 +806,8 @@ LSR1 = Run(
795
806
 
796
807
  # ---------------------------- second_order/newton --------------------------- #
797
808
  Newton = Run(
798
- func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
799
- sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
809
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
810
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
800
811
  needs_closure=True,
801
812
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
802
813
  sphere_steps=2, sphere_loss=1e-9,
@@ -804,8 +815,8 @@ Newton = Run(
804
815
 
805
816
  # --------------------------- second_order/newton_cg -------------------------- #
806
817
  NewtonCG = Run(
807
- func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
808
- sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
818
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
819
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
809
820
  needs_closure=True,
810
821
  func='rosen', steps=20, loss=1e-10, merge_invariant=True,
811
822
  sphere_steps=2, sphere_loss=3e-4,
@@ -813,8 +824,8 @@ NewtonCG = Run(
813
824
 
814
825
  # ---------------------------- smoothing/gaussian ---------------------------- #
815
826
  GaussianHomotopy = Run(
816
- func_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
817
- sphere_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
827
+ func_opt=lambda p: tz.Optimizer(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
828
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
818
829
  needs_closure=True,
819
830
  func='booth', steps=20, loss=0.01, merge_invariant=True,
820
831
  sphere_steps=10, sphere_loss=1,
@@ -822,16 +833,16 @@ GaussianHomotopy = Run(
822
833
 
823
834
  # ---------------------------- smoothing/laplacian --------------------------- #
824
835
  LaplacianSmoothing = Run(
825
- func_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.1)),
826
- sphere_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.5)),
836
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.1)),
837
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.5)),
827
838
  needs_closure=False,
828
839
  func='booth', steps=50, loss=0.4, merge_invariant=False,
829
840
  sphere_steps=10, sphere_loss=3,
830
841
  )
831
842
 
832
843
  LaplacianSmoothing_global = Run(
833
- func_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.1)),
834
- sphere_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.5)),
844
+ func_opt=lambda p: tz.Optimizer(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.1)),
845
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.5)),
835
846
  needs_closure=False,
836
847
  func='booth', steps=50, loss=0.4, merge_invariant=True,
837
848
  sphere_steps=10, sphere_loss=3,
@@ -839,8 +850,8 @@ LaplacianSmoothing_global = Run(
839
850
 
840
851
  # -------------------------- wrappers/optim_wrapper -------------------------- #
841
852
  Wrap = Run(
842
- func_opt=lambda p: tz.Modular(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.5)),
843
- sphere_opt=lambda p: tz.Modular(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.2)),
853
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.5)),
854
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.2)),
844
855
  needs_closure=False,
845
856
  func='rosen', steps=50, loss=4, merge_invariant=True,
846
857
  sphere_steps=20, sphere_loss=4,
@@ -848,15 +859,15 @@ Wrap = Run(
848
859
 
849
860
  # --------------------------- second_order/nystrom --------------------------- #
850
861
  NystromSketchAndSolve = Run(
851
- func_opt=lambda p: tz.Modular(p, tz.m.NystromSketchAndSolve(2, seed=0), tz.m.StrongWolfe()),
852
- sphere_opt=lambda p: tz.Modular(p, tz.m.NystromSketchAndSolve(10, seed=0), tz.m.StrongWolfe()),
862
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NystromSketchAndSolve(2, seed=0), tz.m.StrongWolfe()),
863
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NystromSketchAndSolve(10, seed=0), tz.m.StrongWolfe()),
853
864
  needs_closure=True,
854
865
  func='booth', steps=3, loss=1e-6, merge_invariant=True,
855
866
  sphere_steps=10, sphere_loss=1e-12,
856
867
  )
857
868
  NystromPCG = Run(
858
- func_opt=lambda p: tz.Modular(p, tz.m.NystromPCG(2, seed=0), tz.m.StrongWolfe()),
859
- sphere_opt=lambda p: tz.Modular(p, tz.m.NystromPCG(10, seed=0), tz.m.StrongWolfe()),
869
+ func_opt=lambda p: tz.Optimizer(p, tz.m.NystromPCG(2, seed=0), tz.m.StrongWolfe()),
870
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.NystromPCG(10, seed=0), tz.m.StrongWolfe()),
860
871
  needs_closure=True,
861
872
  func='ill', steps=2, loss=1e-5, merge_invariant=True,
862
873
  sphere_steps=2, sphere_loss=1e-9,
@@ -864,8 +875,8 @@ NystromPCG = Run(
864
875
 
865
876
  # ---------------------------- optimizers/sophia_h --------------------------- #
866
877
  SophiaH = Run(
867
- func_opt=lambda p: tz.Modular(p, tz.m.SophiaH(seed=0), tz.m.LR(0.1)),
868
- sphere_opt=lambda p: tz.Modular(p, tz.m.SophiaH(seed=0), tz.m.LR(0.3)),
878
+ func_opt=lambda p: tz.Optimizer(p, tz.m.SophiaH(seed=0), tz.m.LR(0.1)),
879
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.SophiaH(seed=0), tz.m.LR(0.3)),
869
880
  needs_closure=True,
870
881
  func='ill', steps=50, loss=0.02, merge_invariant=True,
871
882
  sphere_steps=10, sphere_loss=40,
@@ -873,17 +884,17 @@ SophiaH = Run(
873
884
 
874
885
  # -------------------------- higher_order ------------------------- #
875
886
  HigherOrderNewton = Run(
876
- func_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(trust_method=None)),
877
- sphere_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(2, trust_method=None)),
887
+ func_opt=lambda p: tz.Optimizer(p, tz.m.experimental.HigherOrderNewton(trust_method=None)),
888
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.experimental.HigherOrderNewton(2, trust_method=None)),
878
889
  needs_closure=True,
879
890
  func='rosen', steps=1, loss=2e-10, merge_invariant=True,
880
891
  sphere_steps=1, sphere_loss=1e-10,
881
892
  )
882
893
 
883
894
  # ---------------------------- optimizers/ladagrad --------------------------- #
884
- LMAdagrad = Run(
885
- func_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(4)),
886
- sphere_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(5)),
895
+ GGT = Run(
896
+ func_opt=lambda p: tz.Optimizer(p, tz.m.GGT(), tz.m.LR(4)),
897
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.GGT(), tz.m.LR(5)),
887
898
  needs_closure=False,
888
899
  func='booth', steps=50, loss=1e-6, merge_invariant=True,
889
900
  sphere_steps=20, sphere_loss=1e-9,
@@ -891,8 +902,8 @@ LMAdagrad = Run(
891
902
 
892
903
  # ------------------------------ optimizers/adan ----------------------------- #
893
904
  Adan = Run(
894
- func_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(1)),
895
- sphere_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(0.1)),
905
+ func_opt=lambda p: tz.Optimizer(p, tz.m.Adan(), tz.m.LR(1)),
906
+ sphere_opt=lambda p: tz.Optimizer(p, tz.m.Adan(), tz.m.LR(0.1)),
896
907
  needs_closure=False,
897
908
  func='booth', steps=50, loss=60, merge_invariant=True,
898
909
  sphere_steps=20, sphere_loss=60,
@@ -903,8 +914,8 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
903
914
  for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
904
915
  # but also test 10 to make sure it doesn't explode after converging
905
916
  Run(
906
- func_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
907
- sphere_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
917
+ func_opt=lambda p: tz.Optimizer(p, CG(), tz.m.StrongWolfe(c2=0.1)),
918
+ sphere_opt=lambda p: tz.Optimizer(p, CG(), tz.m.StrongWolfe(c2=0.1)),
908
919
  needs_closure=True,
909
920
  func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=True,
910
921
  sphere_steps=sphere_steps_, sphere_loss=0,
@@ -937,8 +948,8 @@ for QN in (
937
948
  tz.m.SSVM,
938
949
  ):
939
950
  Run(
940
- func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
941
- sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
951
+ func_opt=lambda p: tz.Optimizer(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
952
+ sphere_opt=lambda p: tz.Optimizer(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
942
953
  needs_closure=True,
943
954
  func='lstsq', steps=50, loss=1e-10, merge_invariant=True,
944
955
  sphere_steps=10, sphere_loss=1e-20,