torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200) hide show
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
tests/test_opts.py ADDED
@@ -0,0 +1,884 @@
1
+ """sanity tests to make sure everything works and converges on basic functions"""
2
+ from collections.abc import Callable
3
+ from functools import partial
4
+
5
+ import pytest
6
+ import torch
7
+ import torchzero as tz
8
+
9
+ PRINT = False # set to true in nbs
10
+
11
+ def _booth(x, y):
12
+ return (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2
13
+
14
+ def _rosen(x, y):
15
+ return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2
16
+
17
+ def _ill(x, y):
18
+ return x**2 + y**2 + 1.99999*x*y
19
+
20
+ def _lstsq(x,y): # specifically for CG and quasi newton methods, staircase effect is more pronounced there
21
+ return (2*x + 3*y - 5)**2 + (5*x - 2*y - 3)**2
22
+
23
+ funcs = {"booth": (_booth, (0, -8)), "rosen": (_rosen, (-1.1, 2.5)), "ill": (_ill, (-9, 2.5)), "lstsq": (_lstsq, (-0.9, 0))}
24
+ """{"name": (function, x0)}"""
25
+
26
+ class _TestModel(torch.nn.Module):
27
+ """sphere with all kinds of parameter shapes, initial loss is 521.2754"""
28
+ def __init__(self):
29
+ super().__init__()
30
+ generator = torch.Generator().manual_seed(0)
31
+ randn = partial(torch.randn, generator=generator)
32
+ params = [
33
+ torch.tensor(1.), torch.tensor([1.]), torch.tensor([[1.]]),
34
+ randn(10), randn(1,10), randn(10,1), randn(1,1,10),randn(1,10,1),randn(1,1,10),
35
+ randn(10,10), randn(4,4,4), randn(3,3,3,3), randn(2,2,2,2,2,2,2),
36
+ randn(10,1,3,1,1),
37
+ torch.zeros(2,2), torch.ones(2,2),
38
+ ]
39
+ self.params = torch.nn.ParameterList(torch.nn.Parameter(t) for t in params)
40
+
41
+ def forward(self):
42
+ return torch.sum(torch.stack([p.pow(2).sum() for p in self.params]))
43
+
44
+ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, steps: int, clear: bool):
45
+ """generic function to run opt on objective and return lowest recorded loss"""
46
+ losses = []
47
+ for i in range(steps):
48
+ if clear and i == steps//2:
49
+ for m in opt.unrolled_modules: m.reset() # clear on middle step to see if there are any issues with it
50
+
51
+ if use_closure:
52
+ def closure(backward=True):
53
+ loss = objective()
54
+ if backward:
55
+ opt.zero_grad()
56
+ loss.backward()
57
+ return loss
58
+ loss = opt.step(closure)
59
+ assert loss is not None
60
+ assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
61
+ losses.append(loss)
62
+
63
+ else:
64
+ loss = objective()
65
+ opt.zero_grad()
66
+ loss.backward()
67
+ opt.step()
68
+ assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
69
+ losses.append(loss)
70
+
71
+ return torch.stack(losses).nan_to_num(0,10000,10000).min()
72
+
73
+ def _run_func(opt_fn: Callable, func:str, merge: bool, use_closure: bool, steps: int):
74
+ """run optimizer on a test function and return lowest loss"""
75
+ fn, x0 = funcs[func]
76
+ X = torch.tensor(x0, dtype=torch.float32, requires_grad=True)
77
+ if merge:
78
+ opt = opt_fn([X])
79
+ else:
80
+ x,y = [i.clone().detach().requires_grad_() for i in X]
81
+ X = (x,y)
82
+ opt = opt_fn(X)
83
+
84
+ def objective():
85
+ return fn(*X)
86
+
87
+ return _run_objective(opt, objective, use_closure, steps, clear=False), opt
88
+
89
+ def _run_sphere(opt_fn: Callable, use_closure:bool, steps:int):
90
+ """run optimizer on sphere test module to test different parameter shapes (common cause of mistakes)"""
91
+ sphere = _TestModel()
92
+ opt = opt_fn(sphere.parameters())
93
+ return _run_objective(opt, sphere, use_closure, steps, clear=True), opt
94
+
95
+ def _run(func_opt: Callable, sphere_opt: Callable, needs_closure: bool, func:str, steps: int, loss: float, merge_invariant: bool, sphere_steps: int, sphere_loss: float):
96
+ """Run optimizer on function and sphere test module and check that loss is low enough"""
97
+ tested_sphere = {True: False, False: False} # because sphere has no merge
98
+ merged_losses = []
99
+ unmerged_losses = []
100
+ sphere_losses = []
101
+
102
+ for merge in [True, False]:
103
+ for use_closure in [True] if needs_closure else [True, False]:
104
+ if PRINT: print(f"testing with {merge = }, {use_closure = }")
105
+ v,opt = _run_func(func_opt, func, merge, use_closure, steps)
106
+ if PRINT: print(f'{func} loss after {steps} steps is {v}, target is {loss}')
107
+ assert v <= loss, f"{opt}: Loss on {func} is {v}, which is above target {loss}. {merge = }, {use_closure = }"
108
+ if merge: merged_losses.append(v)
109
+ else: unmerged_losses.append(v)
110
+
111
+ if not tested_sphere[use_closure]:
112
+ tested_sphere[use_closure] = True
113
+ v,opt = _run_sphere(sphere_opt, use_closure, sphere_steps)
114
+ if PRINT: print(f'sphere loss after {sphere_steps} is {v}, target is {sphere_loss}')
115
+ assert v <= sphere_loss, f"{opt}: Loss on sphere is {v}, which is above target {sphere_loss}. {merge = }, {use_closure = }"
116
+ sphere_losses.append(v)
117
+ if PRINT: print()
118
+
119
+ # test if losses match
120
+ if merge_invariant: losses = merged_losses + unmerged_losses
121
+ else: losses = merged_losses
122
+ l = losses[0]
123
+ assert all(i == l for i in losses), f"{func} losses don't match: {[l.item() for l in losses]}"
124
+
125
+ l = unmerged_losses[0]
126
+ assert all(i == l for i in unmerged_losses), f"Sphere losses don't match: {[l.item() for l in unmerged_losses]}"
127
+
128
+
129
+ l = sphere_losses[0]
130
+ assert all(i == l for i in sphere_losses), f"Sphere losses don't match: {[l.item() for l in sphere_losses]}"
131
+
132
+ RUNS = []
133
+ """Whenever a Run is created (__init__ is called) it gets appended to this"""
134
+
135
+ class Run:
136
+ """
137
+ Holds arguments for a test.
138
+
139
+ Args:
140
+ func_opt (Callable): opt for test function e.g. :code:`lambda p: tz.Modular(p, tz.m.Adam())`
141
+ sphere_opt (Callable): opt for sphere e.g. :code:`lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.1))`
142
+ needs_closure (bool): set to True if opt_fn requires closure
143
+ func (str): what test function to use ("booth", "rosen", "ill", "lstsq")
144
+ steps (int): number of steps to run test function for.
145
+ loss (float): if minimal loss is higher than this then test fails
146
+ merge_invariant (bool): whether the optimizer is invariant to parameters merged or separated.
147
+ sphere_steps (int): how many steps to run sphere for (it has 474 scalar parameters across 16 tensors)
148
+ sphere_loss (float): if minimal loss is higher than this then test fails
149
+ """
150
+ def __init__(self, func_opt: Callable, sphere_opt: Callable, needs_closure: bool, func: str, steps: int, loss:float, merge_invariant: bool, sphere_steps:int, sphere_loss:float):
151
+ self.kwargs = locals().copy()
152
+ del self.kwargs['self']
153
+ RUNS.append(self)
154
+ def test(self): _run(**self.kwargs)
155
+
156
+ # target losses for all of those are set to just above what they reach
157
+ # ---------------------------------------------------------------------------- #
158
+ # tests #
159
+ # ---------------------------------------------------------------------------- #
160
+ # ----------------------------- clipping/clipping ---------------------------- #
161
+ ClipValue = Run(
162
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipValue(1), tz.m.LR(1)),
163
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValue(1), tz.m.LR(1)),
164
+ needs_closure=False,
165
+ func='booth', steps=50, loss=0, merge_invariant=True,
166
+ sphere_steps=10, sphere_loss=50,
167
+ )
168
+ ClipNorm = Run(
169
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1), tz.m.LR(1)),
170
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1), tz.m.LR(0.5)),
171
+ needs_closure=False,
172
+ func='booth', steps=50, loss=2, merge_invariant=False,
173
+ sphere_steps=10, sphere_loss=0,
174
+ )
175
+ ClipNorm_global = Run(
176
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(1)),
177
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNorm(1, dim='global'), tz.m.LR(3)),
178
+ needs_closure=False,
179
+ func='booth', steps=50, loss=2, merge_invariant=True,
180
+ sphere_steps=10, sphere_loss=2,
181
+ )
182
+ Normalize = Run(
183
+ func_opt=lambda p: tz.Modular(p, tz.m.Normalize(1), tz.m.LR(1)),
184
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Normalize(1), tz.m.LR(0.5)),
185
+ needs_closure=False,
186
+ func='booth', steps=50, loss=2, merge_invariant=False,
187
+ sphere_steps=10, sphere_loss=15,
188
+ )
189
+ Normalize_global = Run(
190
+ func_opt=lambda p: tz.Modular(p, tz.m.Normalize(1, dim='global'), tz.m.LR(1)),
191
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Normalize(1, dim='global'), tz.m.LR(4)),
192
+ needs_closure=False,
193
+ func='booth', steps=50, loss=2, merge_invariant=True,
194
+ sphere_steps=10, sphere_loss=2,
195
+ )
196
+ Centralize = Run(
197
+ func_opt=lambda p: tz.Modular(p, tz.m.Centralize(min_size=3), tz.m.LR(0.1)),
198
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Centralize(), tz.m.LR(0.1)),
199
+ needs_closure=False,
200
+ func='booth', steps=50, loss=1e-6, merge_invariant=False,
201
+ sphere_steps=10, sphere_loss=10,
202
+ )
203
+ Centralize_global = Run(
204
+ func_opt=lambda p: tz.Modular(p, tz.m.Centralize(min_size=3, dim='global'), tz.m.LR(0.1)),
205
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Centralize(dim='global'), tz.m.LR(0.1)),
206
+ needs_closure=False,
207
+ func='booth', steps=1, loss=1000, merge_invariant=True,
208
+ sphere_steps=10, sphere_loss=10,
209
+ )
210
+
211
+ # --------------------------- clipping/ema_clipping -------------------------- #
212
+ ClipNormByEMA = Run(
213
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(), tz.m.LR(0.1)),
214
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(), tz.m.LR(5)),
215
+ needs_closure=False,
216
+ func='booth', steps=50, loss=1e-5, merge_invariant=False,
217
+ sphere_steps=10, sphere_loss=0.1,
218
+ )
219
+ ClipNormByEMA_global = Run(
220
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(0.1)),
221
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormByEMA(tensorwise=False), tz.m.LR(5)),
222
+ needs_closure=False,
223
+ func='booth', steps=50, loss=1e-5, merge_invariant=True,
224
+ sphere_steps=10, sphere_loss=0.1,
225
+ )
226
+ NormalizeByEMA = Run(
227
+ func_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(), tz.m.LR(0.05)),
228
+ sphere_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(), tz.m.LR(5)),
229
+ needs_closure=False,
230
+ func='booth', steps=50, loss=1, merge_invariant=False,
231
+ sphere_steps=10, sphere_loss=0.1,
232
+ )
233
+ NormalizeByEMA_global = Run(
234
+ func_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(0.05)),
235
+ sphere_opt=lambda p: tz.Modular(p, tz.m.NormalizeByEMA(tensorwise=False), tz.m.LR(5)),
236
+ needs_closure=False,
237
+ func='booth', steps=50, loss=1, merge_invariant=True,
238
+ sphere_steps=10, sphere_loss=0.1,
239
+ )
240
+ ClipValueByEMA = Run(
241
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipValueByEMA(), tz.m.LR(0.1)),
242
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueByEMA(), tz.m.LR(4)),
243
+ needs_closure=False,
244
+ func='booth', steps=50, loss=1e-5, merge_invariant=True,
245
+ sphere_steps=10, sphere_loss=0.03,
246
+ )
247
+ # ------------------------- clipping/growth_clipping ------------------------- #
248
+ ClipValueGrowth = Run(
249
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
250
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(), tz.m.LR(0.1)),
251
+ needs_closure=False,
252
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
253
+ sphere_steps=10, sphere_loss=100,
254
+ )
255
+ ClipValueGrowth_additive = Run(
256
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
257
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipValueGrowth(add=1, mul=None), tz.m.LR(0.1)),
258
+ needs_closure=False,
259
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
260
+ sphere_steps=10, sphere_loss=10,
261
+ )
262
+ ClipNormGrowth = Run(
263
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
264
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(), tz.m.LR(0.1)),
265
+ needs_closure=False,
266
+ func='booth', steps=50, loss=1e-6, merge_invariant=False,
267
+ sphere_steps=10, sphere_loss=10,
268
+ )
269
+ ClipNormGrowth_additive = Run(
270
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
271
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(add=1,mul=None), tz.m.LR(0.1)),
272
+ needs_closure=False,
273
+ func='booth', steps=50, loss=1e-6, merge_invariant=False,
274
+ sphere_steps=10, sphere_loss=10,
275
+ )
276
+ ClipNormGrowth_global = Run(
277
+ func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(parameterwise=False), tz.m.LR(0.1)),
278
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(parameterwise=False), tz.m.LR(0.1)),
279
+ needs_closure=False,
280
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
281
+ sphere_steps=10, sphere_loss=10,
282
+ )
283
+
284
+ # -------------------------- grad_approximation/fdm -------------------------- #
285
+ FDM_central2 = Run(
286
+ func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
287
+ sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
288
+ needs_closure=True,
289
+ func='booth', steps=50, loss=1e-7, merge_invariant=True,
290
+ sphere_steps=2, sphere_loss=340,
291
+ )
292
+ FDM_forward2 = Run(
293
+ func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
294
+ sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
295
+ needs_closure=True,
296
+ func='booth', steps=50, loss=1e-7, merge_invariant=True,
297
+ sphere_steps=2, sphere_loss=340,
298
+ )
299
+ FDM_backward2 = Run(
300
+ func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
301
+ sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
302
+ needs_closure=True,
303
+ func='booth', steps=50, loss=2e-7, merge_invariant=True,
304
+ sphere_steps=2, sphere_loss=340,
305
+ )
306
+ FDM_forward3 = Run(
307
+ func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
308
+ sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
309
+ needs_closure=True,
310
+ func='booth', steps=50, loss=3e-7, merge_invariant=True,
311
+ sphere_steps=2, sphere_loss=340,
312
+ )
313
+ FDM_backward3 = Run(
314
+ func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
315
+ sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
316
+ needs_closure=True,
317
+ func='booth', steps=50, loss=3e-7, merge_invariant=True,
318
+ sphere_steps=2, sphere_loss=340,
319
+ )
320
+ FDM_central4 = Run(
321
+ func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
322
+ sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
323
+ needs_closure=True,
324
+ func='booth', steps=50, loss=2e-8, merge_invariant=True,
325
+ sphere_steps=2, sphere_loss=340,
326
+ )
327
+
328
+ # -------------------------- grad_approximation/rfdm ------------------------- #
329
+ RandomizedFDM_central2 = Run(
330
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.01)),
331
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.001)),
332
+ needs_closure=True,
333
+ func='booth', steps=50, loss=10, merge_invariant=True,
334
+ sphere_steps=100, sphere_loss=450,
335
+ )
336
+ RandomizedFDM_forward2 = Run(
337
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.01)),
338
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.001)),
339
+ needs_closure=True,
340
+ func='booth', steps=50, loss=10, merge_invariant=True,
341
+ sphere_steps=100, sphere_loss=450,
342
+ )
343
+ RandomizedFDM_backward2 = Run(
344
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.01)),
345
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.001)),
346
+ needs_closure=True,
347
+ func='booth', steps=50, loss=10, merge_invariant=True,
348
+ sphere_steps=100, sphere_loss=450,
349
+ )
350
+ RandomizedFDM_forward3 = Run(
351
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.01)),
352
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.001)),
353
+ needs_closure=True,
354
+ func='booth', steps=50, loss=10, merge_invariant=True,
355
+ sphere_steps=100, sphere_loss=450,
356
+ )
357
+ RandomizedFDM_backward3 = Run(
358
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.01)),
359
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.001)),
360
+ needs_closure=True,
361
+ func='booth', steps=50, loss=10, merge_invariant=True,
362
+ sphere_steps=100, sphere_loss=450,
363
+ )
364
+ RandomizedFDM_central4 = Run(
365
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.01)),
366
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.001)),
367
+ needs_closure=True,
368
+ func='booth', steps=50, loss=10, merge_invariant=True,
369
+ sphere_steps=100, sphere_loss=450,
370
+ )
371
+
372
+ RandomizedFDM_4samples = Run(
373
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
374
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.001)),
375
+ needs_closure=True,
376
+ func='booth', steps=50, loss=1e-5, merge_invariant=True,
377
+ sphere_steps=100, sphere_loss=400,
378
+ )
379
+ RandomizedFDM_4samples_lerp = Run(
380
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.99, seed=0), tz.m.LR(0.1)),
381
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.9, seed=0), tz.m.LR(0.001)),
382
+ needs_closure=True,
383
+ func='booth', steps=50, loss=1e-5, merge_invariant=True,
384
+ sphere_steps=100, sphere_loss=505,
385
+ )
386
+ RandomizedFDM_4samples_no_pre_generate = Run(
387
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
388
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
389
+ needs_closure=True,
390
+ func='booth', steps=50, loss=1e-5, merge_invariant=True,
391
+ sphere_steps=100, sphere_loss=400,
392
+ )
393
+ MeZO = Run(
394
+ func_opt=lambda p: tz.Modular(p, tz.m.MeZO(), tz.m.LR(0.01)),
395
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MeZO(), tz.m.LR(0.001)),
396
+ needs_closure=True,
397
+ func='booth', steps=50, loss=5, merge_invariant=True,
398
+ sphere_steps=100, sphere_loss=450,
399
+ )
400
+ MeZO_4samples = Run(
401
+ func_opt=lambda p: tz.Modular(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.02)),
402
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MeZO(n_samples=4), tz.m.LR(0.005)),
403
+ needs_closure=True,
404
+ func='booth', steps=50, loss=1, merge_invariant=True,
405
+ sphere_steps=100, sphere_loss=250,
406
+ )
407
+ # -------------------- grad_approximation/forward_gradient ------------------- #
408
+ ForwardGradient = Run(  # ForwardGradient(seed=0) with no jvp_method argument, followed by a fixed LR
409
+ func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.01)),
410
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.001)),
411
+ needs_closure=True,
412
+ func='booth', steps=50, loss=40, merge_invariant=True,
413
+ sphere_steps=100, sphere_loss=450,
414
+ )
415
+ ForwardGradient_forward = Run(  # same settings as the first run but with jvp_method='forward'
416
+ func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.01)),
417
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.001)),
418
+ needs_closure=True,
419
+ func='booth', steps=50, loss=40, merge_invariant=True,
420
+ sphere_steps=100, sphere_loss=450,
421
+ )
422
+ ForwardGradient_central = Run(  # same settings as the first run but with jvp_method='central'
423
+ func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.01)),
424
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.001)),
425
+ needs_closure=True,
426
+ func='booth', steps=50, loss=40, merge_invariant=True,
427
+ sphere_steps=100, sphere_loss=450,
428
+ )
429
+ ForwardGradient_4samples = Run(  # n_samples=4; func_opt uses LR(0.1) here instead of LR(0.01)
430
+ func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.1)),
431
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.001)),
432
+ needs_closure=True,
433
+ func='booth', steps=50, loss=0.1, merge_invariant=True,
434
+ sphere_steps=100, sphere_loss=400,
435
+ )
436
+ ForwardGradient_4samples_no_pre_generate = Run(  # as the 4-sample run, with pre_generate=False; same loss values
437
+ func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.1)),
438
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.001)),
439
+ needs_closure=True,
440
+ func='booth', steps=50, loss=0.1, merge_invariant=True,
441
+ sphere_steps=100, sphere_loss=400,
442
+ )
443
+
444
+ # ------------------------- line_search/backtracking ------------------------- #
445
+ Backtracking = Run(  # Backtracking() with default arguments as the only module
446
+ func_opt=lambda p: tz.Modular(p, tz.m.Backtracking()),
447
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Backtracking()),
448
+ needs_closure=True,
449
+ func='booth', steps=50, loss=0, merge_invariant=True,
450
+ sphere_steps=2, sphere_loss=0,
451
+ )
452
+ Backtracking_try_negative = Run(  # Mul(-1) placed before Backtracking(try_negative=True)
453
+ func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
454
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
455
+ needs_closure=True,
456
+ func='booth', steps=50, loss=1e-9, merge_invariant=True,
457
+ sphere_steps=2, sphere_loss=1e-10,
458
+ )
459
+ AdaptiveBacktracking = Run(  # AdaptiveBacktracking() with default arguments as the only module
460
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
461
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
462
+ needs_closure=True,
463
+ func='booth', steps=50, loss=0, merge_invariant=True,
464
+ sphere_steps=2, sphere_loss=0,
465
+ )
466
+ AdaptiveBacktracking_try_negative = Run(  # Mul(-1) placed before AdaptiveBacktracking(try_negative=True)
467
+ func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
468
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
469
+ needs_closure=True,
470
+ func='booth', steps=50, loss=1e-8, merge_invariant=True,
471
+ sphere_steps=2, sphere_loss=1e-10,
472
+ )
473
+ # ----------------------------- line_search/scipy ---------------------------- #
474
+ ScipyMinimizeScalar = Run(
475
+ func_opt=lambda p: tz.Modular(p, tz.m.ScipyMinimizeScalar(maxiter=10)),
476
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking(maxiter=10)),
477
+ needs_closure=True,
478
+ func='booth', steps=50, loss=1e-2, merge_invariant=True,
479
+ sphere_steps=2, sphere_loss=0,
480
+ )
481
+
482
+ # ------------------------- line_search/strong_wolfe ------------------------- #
483
+ StrongWolfe = Run(  # StrongWolfe() with default arguments as the only module
484
+ func_opt=lambda p: tz.Modular(p, tz.m.StrongWolfe()),
485
+ sphere_opt=lambda p: tz.Modular(p, tz.m.StrongWolfe()),
486
+ needs_closure=True,
487
+ func='booth', steps=50, loss=0, merge_invariant=True,
488
+ sphere_steps=2, sphere_loss=0,
489
+ )
490
+
491
+ # ------------------------- line_search/trust_region ------------------------- #
492
+ TrustRegion = Run(  # TrustRegion() defaults for func_opt; sphere_opt passes init=0.1
493
+ func_opt=lambda p: tz.Modular(p, tz.m.TrustRegion()),
494
+ sphere_opt=lambda p: tz.Modular(p, tz.m.TrustRegion(init=0.1)),
495
+ needs_closure=True,
496
+ func='booth', steps=50, loss=0.1, merge_invariant=True,
497
+ sphere_steps=10, sphere_loss=1e-5,
498
+ )
499
+
500
+ # ----------------------------------- lr/lr ---------------------------------- #
501
+ LR = Run(  # LR as the only module: 0.1 for func_opt, 0.5 for sphere_opt
502
+ func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
503
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LR(0.5)),
504
+ needs_closure=False,
505
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
506
+ sphere_steps=10, sphere_loss=0,
507
+ )
508
+ StepSize = Run(  # StepSize with the same values and loss settings as the LR run
509
+ func_opt=lambda p: tz.Modular(p, tz.m.StepSize(0.1)),
510
+ sphere_opt=lambda p: tz.Modular(p, tz.m.StepSize(0.5)),
511
+ needs_closure=False,
512
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
513
+ sphere_steps=10, sphere_loss=0,
514
+ )
515
+ Warmup = Run(  # Warmup(steps=50, end_lr=0.1) for func_opt; sphere_opt passes steps=10 and no end_lr
516
+ func_opt=lambda p: tz.Modular(p, tz.m.Warmup(steps=50, end_lr=0.1)),
517
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Warmup(steps=10)),
518
+ needs_closure=False,
519
+ func='booth', steps=50, loss=0.003, merge_invariant=True,
520
+ sphere_steps=10, sphere_loss=0.05,
521
+ )
522
+ # ------------------------------- lr/step_size ------------------------------- #
523
+ PolyakStepSize = Run(  # PolyakStepSize() with default arguments; this run sets needs_closure=True
524
+ func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
525
+ sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
526
+ needs_closure=True,
527
+ func='booth', steps=50, loss=1e-11, merge_invariant=True,
528
+ sphere_steps=10, sphere_loss=0.002,
529
+ )
530
+ RandomStepSize = Run(  # RandomStepSize(0, 0.1) with seed=0
531
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, seed=0)),
532
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, seed=0)),
533
+ needs_closure=False,
534
+ func='booth', steps=50, loss=0.0005, merge_invariant=True,
535
+ sphere_steps=10, sphere_loss=100,
536
+ )
537
+ RandomStepSize_parameterwise = Run(  # adds parameterwise=True; merge_invariant is False for this run
538
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
539
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomStepSize(0,0.1, parameterwise=True, seed=0)),
540
+ needs_closure=False,
541
+ func='booth', steps=50, loss=0.0005, merge_invariant=False,
542
+ sphere_steps=10, sphere_loss=100,
543
+ )
544
+
545
+ # ---------------------------- momentum/averaging ---------------------------- #
546
+ Averaging = Run(  # Averaging(10) followed by LR
547
+ func_opt=lambda p: tz.Modular(p, tz.m.Averaging(10), tz.m.LR(0.02)),
548
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Averaging(10), tz.m.LR(0.2)),
549
+ needs_closure=False,
550
+ func='booth', steps=50, loss=0.5, merge_invariant=True,
551
+ sphere_steps=10, sphere_loss=0.05,
552
+ )
553
+ WeightedAveraging = Run(  # WeightedAveraging with five explicit weights, the last being 0
554
+ func_opt=lambda p: tz.Modular(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.05)),
555
+ sphere_opt=lambda p: tz.Modular(p, tz.m.WeightedAveraging([1,0.75,0.5,0.25,0]), tz.m.LR(0.5)),
556
+ needs_closure=False,
557
+ func='booth', steps=50, loss=1, merge_invariant=True,
558
+ sphere_steps=10, sphere_loss=2,
559
+ )
560
+ MedianAveraging = Run(  # MedianAveraging(10) followed by LR
561
+ func_opt=lambda p: tz.Modular(p, tz.m.MedianAveraging(10), tz.m.LR(0.05)),
562
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MedianAveraging(10), tz.m.LR(0.5)),
563
+ needs_closure=False,
564
+ func='booth', steps=50, loss=0.005, merge_invariant=True,
565
+ sphere_steps=10, sphere_loss=0,
566
+ )
567
+
568
+ # ----------------------------- momentum/cautious ---------------------------- #
569
+ Cautious = Run(  # chain: HeavyBall(0.9), Cautious(), LR(0.1)
570
+ func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
571
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Cautious(), tz.m.LR(0.1)),
572
+ needs_closure=False,
573
+ func='booth', steps=50, loss=0.003, merge_invariant=True,
574
+ sphere_steps=10, sphere_loss=2,
575
+ )
576
+ UpdateGradientSignConsistency = Run(  # the module is passed as the argument of Mul; loss values match the Cautious run
577
+ func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
578
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.Mul(tz.m.UpdateGradientSignConsistency()), tz.m.LR(0.1)),
579
+ needs_closure=False,
580
+ func='booth', steps=50, loss=0.003, merge_invariant=True,
581
+ sphere_steps=10, sphere_loss=2,
582
+ )
583
+ IntermoduleCautious = Run(  # IntermoduleCautious receives NAG() and BFGS() as its two arguments
584
+ func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.01)),
585
+ sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.1)),
586
+ needs_closure=False,
587
+ func='booth', steps=50, loss=1e-4, merge_invariant=True,
588
+ sphere_steps=10, sphere_loss=0.1,
589
+ )
590
+ ScaleByGradCosineSimilarity = Run(  # chain: HeavyBall(0.9), ScaleByGradCosineSimilarity(), LR
591
+ func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.01)),
592
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(0.9), tz.m.ScaleByGradCosineSimilarity(), tz.m.LR(0.1)),
593
+ needs_closure=False,
594
+ func='booth', steps=50, loss=0.1, merge_invariant=True,
595
+ sphere_steps=10, sphere_loss=0.1,
596
+ )
597
+ ScaleModulesByCosineSimilarity = Run(  # ScaleModulesByCosineSimilarity receives HeavyBall(0.9) and BFGS()
598
+ func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.05)),
599
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.1)),
600
+ needs_closure=False,
601
+ func='booth', steps=50, loss=0.005, merge_invariant=True,
602
+ sphere_steps=10, sphere_loss=0.1,
603
+ )
604
+
605
+ # ------------------------- momentum/matrix_momentum ------------------------- #
606
+ MatrixMomentum_forward = Run(
607
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.01)),
608
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.5)),
609
+ needs_closure=True,
610
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
611
+ sphere_steps=10, sphere_loss=0,
612
+ )
613
+ MatrixMomentum_forward = Run(
614
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='central'), tz.m.LR(0.01)),
615
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='central'), tz.m.LR(0.5)),
616
+ needs_closure=True,
617
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
618
+ sphere_steps=10, sphere_loss=0,
619
+ )
620
+ MatrixMomentum_forward = Run(
621
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.01)),
622
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.5)),
623
+ needs_closure=True,
624
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
625
+ sphere_steps=10, sphere_loss=0,
626
+ )
627
+
628
+ AdaptiveMatrixMomentum_forward = Run(  # AdaptiveMatrixMomentum with hvp_mode='forward'
629
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='forward'), tz.m.LR(0.05)),
630
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='forward'), tz.m.LR(0.5)),
631
+ needs_closure=True,
632
+ func='booth', steps=50, loss=0.002, merge_invariant=True,
633
+ sphere_steps=10, sphere_loss=0,
634
+ )
635
+ AdaptiveMatrixMomentum_central = Run(  # same settings with hvp_mode='central'
636
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='central'), tz.m.LR(0.05)),
637
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='central'), tz.m.LR(0.5)),
638
+ needs_closure=True,
639
+ func='booth', steps=50, loss=0.002, merge_invariant=True,
640
+ sphere_steps=10, sphere_loss=0,
641
+ )
642
+ AdaptiveMatrixMomentum_autograd = Run(  # same settings with hvp_mode='autograd'
643
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.05)),
644
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.5)),
645
+ needs_closure=True,
646
+ func='booth', steps=50, loss=0.002, merge_invariant=True,
647
+ sphere_steps=10, sphere_loss=0,
648
+ )
649
+
650
+ # EMA, momentum are covered by test_identical
651
+ # --------------------------------- ops/misc --------------------------------- #
652
+ Previous = Run(  # Previous(10) for func_opt, Previous(3) for sphere_opt
653
+ func_opt=lambda p: tz.Modular(p, tz.m.Previous(10), tz.m.LR(0.05)),
654
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Previous(3), tz.m.LR(0.5)),
655
+ needs_closure=False,
656
+ func='booth', steps=50, loss=15, merge_invariant=True,
657
+ sphere_steps=10, sphere_loss=0,
658
+ )
659
+ GradSign = Run(  # chain: HeavyBall(), GradSign(), LR
660
+ func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.05)),
661
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.GradSign(), tz.m.LR(0.5)),
662
+ needs_closure=False,
663
+ func='booth', steps=50, loss=0.0002, merge_invariant=True,
664
+ sphere_steps=10, sphere_loss=0.1,
665
+ )
666
+ UpdateSign = Run(  # chain: HeavyBall(), UpdateSign(), LR
667
+ func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.05)),
668
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.UpdateSign(), tz.m.LR(0.5)),
669
+ needs_closure=False,
670
+ func='booth', steps=50, loss=0.01, merge_invariant=True,
671
+ sphere_steps=10, sphere_loss=0,
672
+ )
673
+ GradAccumulation = Run(  # the module under test is tz.m.GradientAccumulation, given an LR module and 10
674
+ func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.05), 10), ),
675
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.5), 10), ),
676
+ needs_closure=False,
677
+ func='booth', steps=50, loss=25, merge_invariant=True,
678
+ sphere_steps=20, sphere_loss=1e-11,
679
+ )
680
+ NegateOnLossIncrease = Run(  # NegateOnLossIncrease() is the last module, placed after LR
681
+ func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(),),
682
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(),),
683
+ needs_closure=True,
684
+ func='booth', steps=50, loss=0.1, merge_invariant=True,
685
+ sphere_steps=20, sphere_loss=0.001,
686
+ )
687
+ # -------------------------------- misc/switch ------------------------------- #
688
+ Alternate = Run(  # Alternate receives Adagrad(), Adam() and RMSprop(), followed by LR(1)
689
+ func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
690
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
691
+ needs_closure=False,
692
+ func='booth', steps=50, loss=1, merge_invariant=True,
693
+ sphere_steps=20, sphere_loss=20,
694
+ )
695
+
696
+ # ------------------------------ optimizers/adam ----------------------------- #
697
+ Adam = Run(  # Adam() defaults followed by LR; this run uses func='rosen'
698
+ func_opt=lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.5)),
699
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Adam(), tz.m.LR(0.2)),
700
+ needs_closure=False,
701
+ func='rosen', steps=50, loss=4, merge_invariant=True,
702
+ sphere_steps=20, sphere_loss=4,
703
+ )
704
+ # ------------------------------ optimizers/soap ----------------------------- #
705
+ SOAP = Run(  # SOAP() defaults followed by LR; merge_invariant is False for this run
706
+ func_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(0.4)),
707
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(1)),
708
+ needs_closure=False,
709
+ func='rosen', steps=50, loss=4, merge_invariant=False,
710
+ sphere_steps=20, sphere_loss=25, # merge and unmerge lrs are very different so need to test convergence separately somewhere
711
+ )
712
+ # ------------------------------ optimizers/lion ----------------------------- #
713
+ Lion = Run(  # Lion() defaults followed by LR(1) for func_opt and LR(0.1) for sphere_opt
714
+ func_opt=lambda p: tz.Modular(p, tz.m.Lion(), tz.m.LR(1)),
715
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Lion(), tz.m.LR(0.1)),
716
+ needs_closure=False,
717
+ func='booth', steps=50, loss=0, merge_invariant=True,
718
+ sphere_steps=20, sphere_loss=25,
719
+ )
720
+ # ---------------------------- optimizers/shampoo ---------------------------- #
721
+ Shampoo = Run(  # Shampoo() is passed to GraftModules together with RMSprop(); merge_invariant is False
722
+ func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
723
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.2)),
724
+ needs_closure=False,
725
+ func='booth', steps=50, loss=200, merge_invariant=False,
726
+ sphere_steps=20, sphere_loss=1e-4, # merge and unmerge lrs are very different so need to test convergence separately somewhere
727
+ )
728
+
729
+ # ------------------------- quasi_newton/quasi_newton ------------------------ #
730
+ BFGS = Run(  # BFGS() followed by StrongWolfe(), with func='rosen'
731
+ func_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
732
+ sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
733
+ needs_closure=True,
734
+ func='rosen', steps=50, loss=0, merge_invariant=True,
735
+ sphere_steps=10, sphere_loss=0,
736
+ )
737
+ SR1 = Run(  # SR1() followed by StrongWolfe(); loss is 1e-12 rather than 0
738
+ func_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
739
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
740
+ needs_closure=True,
741
+ func='rosen', steps=50, loss=1e-12, merge_invariant=True,
742
+ sphere_steps=10, sphere_loss=0,
743
+ )
744
+ SSVM = Run(  # SSVM(1) followed by StrongWolfe(); loss is 1e-12 rather than 0
745
+ func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
746
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
747
+ needs_closure=True,
748
+ func='rosen', steps=50, loss=1e-12, merge_invariant=True,
749
+ sphere_steps=10, sphere_loss=0,
750
+ )
751
+
752
+ # ---------------------------- quasi_newton/lbfgs ---------------------------- #
753
+ LBFGS = Run(  # LBFGS() defaults followed by StrongWolfe(), with func='rosen'
754
+ func_opt=lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
755
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.StrongWolfe()),
756
+ needs_closure=True,
757
+ func='rosen', steps=50, loss=0, merge_invariant=True,
758
+ sphere_steps=10, sphere_loss=0,
759
+ )
760
+
761
+ # ----------------------------- quasi_newton/lsr1 ---------------------------- #
762
+ LSR1 = Run(  # LSR1() defaults followed by StrongWolfe(), with func='rosen'
763
+ func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
764
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
765
+ needs_closure=True,
766
+ func='rosen', steps=50, loss=0, merge_invariant=True,
767
+ sphere_steps=10, sphere_loss=0,
768
+ )
769
+
770
+ # ---------------------------- quasi_newton/olbfgs --------------------------- #
771
+ OnlineLBFGS = Run(  # OnlineLBFGS() defaults followed by StrongWolfe(), with func='rosen'
772
+ func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
773
+ sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
774
+ needs_closure=True,
775
+ func='rosen', steps=50, loss=0, merge_invariant=True,
776
+ sphere_steps=10, sphere_loss=0,
777
+ )
778
+
779
+ # ---------------------------- second_order/newton --------------------------- #
780
+ Newton = Run(  # Newton() followed by StrongWolfe(); 20 steps on 'rosen', 2 sphere steps
781
+ func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
782
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
783
+ needs_closure=True,
784
+ func='rosen', steps=20, loss=1e-7, merge_invariant=True,
785
+ sphere_steps=2, sphere_loss=1e-9,
786
+ )
787
+
788
+ # --------------------------- second_order/newton_cg -------------------------- #
789
+ NewtonCG = Run(  # NewtonCG() followed by StrongWolfe(); 20 steps on 'rosen', 2 sphere steps
790
+ func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
791
+ sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
792
+ needs_closure=True,
793
+ func='rosen', steps=20, loss=1e-7, merge_invariant=True,
794
+ sphere_steps=2, sphere_loss=1e-6,
795
+ )
796
+
797
+ # ---------------------------- smoothing/gaussian ---------------------------- #
798
+ GaussianHomotopy = Run(  # chain: GaussianHomotopy(10, 1, tol=1e-1, seed=0), BFGS(), StrongWolfe()
799
+ func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
800
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
801
+ needs_closure=True,
802
+ func='booth', steps=20, loss=0.1, merge_invariant=True,
803
+ sphere_steps=10, sphere_loss=150, # merge and unmerge lrs are very different so need to test convergence separately somewhere
804
+ )
805
+
806
+ # ---------------------------- smoothing/laplacian --------------------------- #
807
+ LaplacianSmoothing = Run(  # LaplacianSmoothing(min_numel=1) followed by LR; merge_invariant is False
808
+ func_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.1)),
809
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(min_numel=1), tz.m.LR(0.5)),
810
+ needs_closure=False,
811
+ func='booth', steps=50, loss=0.4, merge_invariant=False,
812
+ sphere_steps=10, sphere_loss=3, # merge and unmerge lrs are very different so need to test convergence separately somewhere
813
+ )
814
+
815
+ LaplacianSmoothing_global = Run(  # layerwise=False and no min_numel argument; merge_invariant is True here
816
+ func_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.1)),
817
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LaplacianSmoothing(layerwise=False), tz.m.LR(0.5)),
818
+ needs_closure=False,
819
+ func='booth', steps=50, loss=0.4, merge_invariant=True,
820
+ sphere_steps=10, sphere_loss=3, # merge and unmerge lrs are very different so need to test convergence separately somewhere
821
+ )
822
+
823
+ # -------------------------- wrappers/optim_wrapper -------------------------- #
824
+ Wrap = Run(  # Wrap is given torch.optim.Adam with lr=1; LR values and loss settings match the Adam run
825
+ func_opt=lambda p: tz.Modular(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.5)),
826
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Wrap(torch.optim.Adam, lr=1), tz.m.LR(0.2)),
827
+ needs_closure=False,
828
+ func='rosen', steps=50, loss=4, merge_invariant=True,
829
+ sphere_steps=20, sphere_loss=4,
830
+ )
831
+
832
+ # --------------------------- second_order/nystrom --------------------------- #
833
+ NystromSketchAndSolve = Run(  # first positional argument is 2 for func_opt and 10 for sphere_opt; seed=0
834
+ func_opt=lambda p: tz.Modular(p, tz.m.NystromSketchAndSolve(2, seed=0), tz.m.StrongWolfe()),
835
+ sphere_opt=lambda p: tz.Modular(p, tz.m.NystromSketchAndSolve(10, seed=0), tz.m.StrongWolfe()),
836
+ needs_closure=True,
837
+ func='booth', steps=3, loss=1e-8, merge_invariant=True,
838
+ sphere_steps=10, sphere_loss=1e-12,
839
+ )
840
+ NystromPCG = Run(  # same 2 / 10 arguments as the run above, with func='ill' and 2 steps each
841
+ func_opt=lambda p: tz.Modular(p, tz.m.NystromPCG(2, seed=0), tz.m.StrongWolfe()),
842
+ sphere_opt=lambda p: tz.Modular(p, tz.m.NystromPCG(10, seed=0), tz.m.StrongWolfe()),
843
+ needs_closure=True,
844
+ func='ill', steps=2, loss=1e-5, merge_invariant=True,
845
+ sphere_steps=2, sphere_loss=1e-9,
846
+ )
847
+
848
+ # ---------------------------- optimizers/sophia_h --------------------------- #
849
+ SophiaH = Run(  # SophiaH(seed=0) followed by LR, with func='ill'
850
+ func_opt=lambda p: tz.Modular(p, tz.m.SophiaH(seed=0), tz.m.LR(0.1)),
851
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SophiaH(seed=0), tz.m.LR(0.3)),
852
+ needs_closure=True,
853
+ func='ill', steps=50, loss=0.02, merge_invariant=True,
854
+ sphere_steps=10, sphere_loss=40,
855
+ )
856
+
857
+ # ------------------------------------ CGs ----------------------------------- #
858
+ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY):
859
+ for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
860
+ # but also test 10 to make sure it doesn't explode after converging
861
+ Run(
862
+ func_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
863
+ sphere_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
864
+ needs_closure=True,
865
+ func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=False, # strong wolfe adds float imprecision
866
+ sphere_steps=sphere_steps_, sphere_loss=0,
867
+ )
868
+
869
+ # ------------------------------- QN stability ------------------------------- #
870
+ # stability test
871
+ for QN in (tz.m.BFGS, tz.m.SR1, tz.m.DFP, tz.m.BroydenGood, tz.m.BroydenBad, tz.m.Greenstadt1, tz.m.Greenstadt2, tz.m.ColumnUpdatingMethod, tz.m.ThomasOptimalMethod, tz.m.PSB, tz.m.Pearson2, tz.m.SSVM):
872
+ Run(
873
+ func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
874
+ sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
875
+ needs_closure=True,
876
+ func='lstsq', steps=50, loss=1e-10, merge_invariant=False,
877
+ sphere_steps=10, sphere_loss=1e-20,
878
+ )
879
+
880
+ # ---------------------------------------------------------------------------- #
881
+ # run #
882
+ # ---------------------------------------------------------------------------- #
883
+ @pytest.mark.parametrize("run", RUNS)  # one pytest case per entry in RUNS
884
+ def test_opt(run: Run): run.test()  # delegates to the Run object's own test() method