torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
tests/test_identical.py
CHANGED
|
@@ -96,8 +96,7 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
|
|
|
96
96
|
|
|
97
97
|
@pytest.mark.parametrize('amsgrad', [True, False])
|
|
98
98
|
def test_adam(amsgrad):
|
|
99
|
-
|
|
100
|
-
# pytorch applies debiasing separately so it is applied before epsilo
|
|
99
|
+
torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
|
|
101
100
|
tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
|
|
102
101
|
tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
|
|
103
102
|
tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
|
|
@@ -133,7 +132,7 @@ def test_adam(amsgrad):
|
|
|
133
132
|
tz.m.Debias2(beta=0.999),
|
|
134
133
|
tz.m.Add(1e-8)]
|
|
135
134
|
))
|
|
136
|
-
tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
|
|
135
|
+
tz_fns = (torch_fn, tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
|
|
137
136
|
|
|
138
137
|
_assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
|
|
139
138
|
for fn in tz_fns:
|
tests/test_opts.py
CHANGED
|
@@ -56,14 +56,17 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
|
|
|
56
56
|
if use_closure:
|
|
57
57
|
def closure(backward=True):
|
|
58
58
|
loss = objective()
|
|
59
|
+
losses.append(loss.detach())
|
|
59
60
|
if backward:
|
|
60
61
|
opt.zero_grad()
|
|
61
62
|
loss.backward()
|
|
62
63
|
return loss
|
|
63
|
-
|
|
64
|
-
assert
|
|
65
|
-
|
|
66
|
-
|
|
64
|
+
ret = opt.step(closure)
|
|
65
|
+
assert ret is not None # the return should be the loss
|
|
66
|
+
with torch.no_grad():
|
|
67
|
+
loss = objective() # in case f(x_0) is not evaluated
|
|
68
|
+
assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
|
|
69
|
+
losses.append(loss.detach())
|
|
67
70
|
|
|
68
71
|
else:
|
|
69
72
|
loss = objective()
|
|
@@ -71,7 +74,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
|
|
|
71
74
|
loss.backward()
|
|
72
75
|
opt.step()
|
|
73
76
|
assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
|
|
74
|
-
losses.append(loss)
|
|
77
|
+
losses.append(loss.detach())
|
|
75
78
|
|
|
76
79
|
losses.append(objective())
|
|
77
80
|
return torch.stack(losses).nan_to_num(0,10000,10000).min()
|
|
@@ -292,42 +295,42 @@ FDM_central2 = Run(
|
|
|
292
295
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
|
|
293
296
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
|
|
294
297
|
needs_closure=True,
|
|
295
|
-
func='booth', steps=50, loss=1e-
|
|
298
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
296
299
|
sphere_steps=2, sphere_loss=340,
|
|
297
300
|
)
|
|
298
301
|
FDM_forward2 = Run(
|
|
299
302
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
|
|
300
303
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
|
|
301
304
|
needs_closure=True,
|
|
302
|
-
func='booth', steps=50, loss=1e-
|
|
305
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
303
306
|
sphere_steps=2, sphere_loss=340,
|
|
304
307
|
)
|
|
305
308
|
FDM_backward2 = Run(
|
|
306
309
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
|
|
307
310
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
|
|
308
311
|
needs_closure=True,
|
|
309
|
-
func='booth', steps=50, loss=
|
|
312
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
310
313
|
sphere_steps=2, sphere_loss=340,
|
|
311
314
|
)
|
|
312
315
|
FDM_forward3 = Run(
|
|
313
316
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
|
|
314
317
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
|
|
315
318
|
needs_closure=True,
|
|
316
|
-
func='booth', steps=50, loss=
|
|
319
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
317
320
|
sphere_steps=2, sphere_loss=340,
|
|
318
321
|
)
|
|
319
322
|
FDM_backward3 = Run(
|
|
320
323
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
|
|
321
324
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
|
|
322
325
|
needs_closure=True,
|
|
323
|
-
func='booth', steps=50, loss=
|
|
326
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
324
327
|
sphere_steps=2, sphere_loss=340,
|
|
325
328
|
)
|
|
326
329
|
FDM_central4 = Run(
|
|
327
330
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
|
|
328
331
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
|
|
329
332
|
needs_closure=True,
|
|
330
|
-
func='booth', steps=50, loss=
|
|
333
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
331
334
|
sphere_steps=2, sphere_loss=340,
|
|
332
335
|
)
|
|
333
336
|
|
|
@@ -374,6 +377,21 @@ RandomizedFDM_central4 = Run(
|
|
|
374
377
|
func='booth', steps=50, loss=10, merge_invariant=True,
|
|
375
378
|
sphere_steps=100, sphere_loss=450,
|
|
376
379
|
)
|
|
380
|
+
RandomizedFDM_forward4 = Run(
|
|
381
|
+
func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
|
|
382
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
|
|
383
|
+
needs_closure=True,
|
|
384
|
+
func='booth', steps=50, loss=10, merge_invariant=True,
|
|
385
|
+
sphere_steps=100, sphere_loss=450,
|
|
386
|
+
)
|
|
387
|
+
RandomizedFDM_forward5 = Run(
|
|
388
|
+
func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
|
|
389
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
|
|
390
|
+
needs_closure=True,
|
|
391
|
+
func='booth', steps=50, loss=10, merge_invariant=True,
|
|
392
|
+
sphere_steps=100, sphere_loss=450,
|
|
393
|
+
)
|
|
394
|
+
|
|
377
395
|
|
|
378
396
|
RandomizedFDM_4samples = Run(
|
|
379
397
|
func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
|
|
@@ -455,25 +473,11 @@ Backtracking = Run(
|
|
|
455
473
|
func='booth', steps=50, loss=0, merge_invariant=True,
|
|
456
474
|
sphere_steps=2, sphere_loss=0,
|
|
457
475
|
)
|
|
458
|
-
Backtracking_try_negative = Run(
|
|
459
|
-
func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
|
|
460
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
|
|
461
|
-
needs_closure=True,
|
|
462
|
-
func='booth', steps=50, loss=1e-9, merge_invariant=True,
|
|
463
|
-
sphere_steps=2, sphere_loss=1e-10,
|
|
464
|
-
)
|
|
465
476
|
AdaptiveBacktracking = Run(
|
|
466
477
|
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
|
|
467
478
|
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
|
|
468
479
|
needs_closure=True,
|
|
469
|
-
func='booth', steps=50, loss=
|
|
470
|
-
sphere_steps=2, sphere_loss=0,
|
|
471
|
-
)
|
|
472
|
-
AdaptiveBacktracking_try_negative = Run(
|
|
473
|
-
func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
|
|
474
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
|
|
475
|
-
needs_closure=True,
|
|
476
|
-
func='booth', steps=50, loss=1e-8, merge_invariant=True,
|
|
480
|
+
func='booth', steps=50, loss=1e-11, merge_invariant=True,
|
|
477
481
|
sphere_steps=2, sphere_loss=1e-10,
|
|
478
482
|
)
|
|
479
483
|
# ----------------------------- line_search/scipy ---------------------------- #
|
|
@@ -494,15 +498,6 @@ StrongWolfe = Run(
|
|
|
494
498
|
sphere_steps=2, sphere_loss=0,
|
|
495
499
|
)
|
|
496
500
|
|
|
497
|
-
# ------------------------- line_search/trust_region ------------------------- #
|
|
498
|
-
TrustRegion = Run(
|
|
499
|
-
func_opt=lambda p: tz.Modular(p, tz.m.TrustRegion()),
|
|
500
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.TrustRegion(init=0.1)),
|
|
501
|
-
needs_closure=True,
|
|
502
|
-
func='booth', steps=50, loss=0.1, merge_invariant=True,
|
|
503
|
-
sphere_steps=10, sphere_loss=1e-5,
|
|
504
|
-
)
|
|
505
|
-
|
|
506
501
|
# ----------------------------------- lr/lr ---------------------------------- #
|
|
507
502
|
LR = Run(
|
|
508
503
|
func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
|
|
@@ -587,8 +582,8 @@ UpdateGradientSignConsistency = Run(
|
|
|
587
582
|
sphere_steps=10, sphere_loss=2,
|
|
588
583
|
)
|
|
589
584
|
IntermoduleCautious = Run(
|
|
590
|
-
func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.01)),
|
|
591
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.1)),
|
|
585
|
+
func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
|
|
586
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
|
|
592
587
|
needs_closure=False,
|
|
593
588
|
func='booth', steps=50, loss=1e-4, merge_invariant=True,
|
|
594
589
|
sphere_steps=10, sphere_loss=0.1,
|
|
@@ -601,8 +596,8 @@ ScaleByGradCosineSimilarity = Run(
|
|
|
601
596
|
sphere_steps=10, sphere_loss=0.1,
|
|
602
597
|
)
|
|
603
598
|
ScaleModulesByCosineSimilarity = Run(
|
|
604
|
-
func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.05)),
|
|
605
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.1)),
|
|
599
|
+
func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
|
|
600
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
|
|
606
601
|
needs_closure=False,
|
|
607
602
|
func='booth', steps=50, loss=0.005, merge_invariant=True,
|
|
608
603
|
sphere_steps=10, sphere_loss=0.1,
|
|
@@ -610,47 +605,69 @@ ScaleModulesByCosineSimilarity = Run(
|
|
|
610
605
|
|
|
611
606
|
# ------------------------- momentum/matrix_momentum ------------------------- #
|
|
612
607
|
MatrixMomentum_forward = Run(
|
|
613
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'),
|
|
614
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward')
|
|
608
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='forward'),),
|
|
609
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward')),
|
|
615
610
|
needs_closure=True,
|
|
616
611
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
617
|
-
sphere_steps=10, sphere_loss=0,
|
|
612
|
+
sphere_steps=10, sphere_loss=0.01,
|
|
618
613
|
)
|
|
619
614
|
MatrixMomentum_forward = Run(
|
|
620
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central')
|
|
621
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central')
|
|
615
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='central')),
|
|
616
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central')),
|
|
622
617
|
needs_closure=True,
|
|
623
618
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
624
|
-
sphere_steps=10, sphere_loss=0,
|
|
619
|
+
sphere_steps=10, sphere_loss=0.01,
|
|
625
620
|
)
|
|
626
621
|
MatrixMomentum_forward = Run(
|
|
627
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd')
|
|
628
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd')
|
|
622
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
|
|
623
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
|
|
629
624
|
needs_closure=True,
|
|
630
625
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
631
|
-
sphere_steps=10, sphere_loss=0,
|
|
626
|
+
sphere_steps=10, sphere_loss=0.01,
|
|
632
627
|
)
|
|
633
628
|
|
|
634
629
|
AdaptiveMatrixMomentum_forward = Run(
|
|
635
|
-
func_opt=lambda p: tz.Modular(p, tz.m.
|
|
636
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.
|
|
630
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True)),
|
|
631
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True)),
|
|
637
632
|
needs_closure=True,
|
|
638
|
-
func='booth', steps=50, loss=0.
|
|
639
|
-
sphere_steps=10, sphere_loss=0,
|
|
633
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
634
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
640
635
|
)
|
|
641
636
|
AdaptiveMatrixMomentum_central = Run(
|
|
642
|
-
func_opt=lambda p: tz.Modular(p, tz.m.
|
|
643
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.
|
|
637
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True)),
|
|
638
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True)),
|
|
644
639
|
needs_closure=True,
|
|
645
|
-
func='booth', steps=50, loss=0.
|
|
646
|
-
sphere_steps=10, sphere_loss=0,
|
|
640
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
641
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
647
642
|
)
|
|
648
643
|
AdaptiveMatrixMomentum_autograd = Run(
|
|
649
|
-
func_opt=lambda p: tz.Modular(p, tz.m.
|
|
650
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.
|
|
644
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
|
|
645
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
|
|
651
646
|
needs_closure=True,
|
|
652
|
-
func='booth', steps=50, loss=0.
|
|
653
|
-
sphere_steps=10, sphere_loss=0,
|
|
647
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
648
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
649
|
+
)
|
|
650
|
+
|
|
651
|
+
StochasticAdaptiveMatrixMomentum_forward = Run(
|
|
652
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True, adapt_freq=1)),
|
|
653
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True, adapt_freq=1)),
|
|
654
|
+
needs_closure=True,
|
|
655
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
656
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
657
|
+
)
|
|
658
|
+
StochasticAdaptiveMatrixMomentum_central = Run(
|
|
659
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True, adapt_freq=1)),
|
|
660
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True, adapt_freq=1)),
|
|
661
|
+
needs_closure=True,
|
|
662
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
663
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
664
|
+
)
|
|
665
|
+
StochasticAdaptiveMatrixMomentum_autograd = Run(
|
|
666
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
|
|
667
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
|
|
668
|
+
needs_closure=True,
|
|
669
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
670
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
654
671
|
)
|
|
655
672
|
|
|
656
673
|
# EMA, momentum are covered by test_identical
|
|
@@ -677,15 +694,15 @@ UpdateSign = Run(
|
|
|
677
694
|
sphere_steps=10, sphere_loss=0,
|
|
678
695
|
)
|
|
679
696
|
GradAccumulation = Run(
|
|
680
|
-
func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.05)
|
|
681
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.5)
|
|
697
|
+
func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
|
|
698
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
|
|
682
699
|
needs_closure=False,
|
|
683
700
|
func='booth', steps=50, loss=25, merge_invariant=True,
|
|
684
701
|
sphere_steps=20, sphere_loss=1e-11,
|
|
685
702
|
)
|
|
686
703
|
NegateOnLossIncrease = Run(
|
|
687
|
-
func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(),),
|
|
688
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(),),
|
|
704
|
+
func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
|
|
705
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
|
|
689
706
|
needs_closure=True,
|
|
690
707
|
func='booth', steps=50, loss=0.1, merge_invariant=True,
|
|
691
708
|
sphere_steps=20, sphere_loss=0.001,
|
|
@@ -693,7 +710,7 @@ NegateOnLossIncrease = Run(
|
|
|
693
710
|
# -------------------------------- misc/switch ------------------------------- #
|
|
694
711
|
Alternate = Run(
|
|
695
712
|
func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
|
|
696
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
|
|
713
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
|
|
697
714
|
needs_closure=False,
|
|
698
715
|
func='booth', steps=50, loss=1, merge_invariant=True,
|
|
699
716
|
sphere_steps=20, sphere_loss=20,
|
|
@@ -734,24 +751,24 @@ Shampoo = Run(
|
|
|
734
751
|
|
|
735
752
|
# ------------------------- quasi_newton/quasi_newton ------------------------ #
|
|
736
753
|
BFGS = Run(
|
|
737
|
-
func_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
|
|
738
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
|
|
754
|
+
func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
|
|
755
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
|
|
739
756
|
needs_closure=True,
|
|
740
|
-
func='rosen', steps=50, loss=
|
|
741
|
-
sphere_steps=10, sphere_loss=
|
|
757
|
+
func='rosen', steps=50, loss=1e-10, merge_invariant=True,
|
|
758
|
+
sphere_steps=10, sphere_loss=1e-10,
|
|
742
759
|
)
|
|
743
760
|
SR1 = Run(
|
|
744
|
-
func_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
|
|
745
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
|
|
761
|
+
func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_restart=True, scale_first=True), tz.m.StrongWolfe(fallback=False)),
|
|
762
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
|
|
746
763
|
needs_closure=True,
|
|
747
764
|
func='rosen', steps=50, loss=1e-12, merge_invariant=True,
|
|
748
765
|
sphere_steps=10, sphere_loss=0,
|
|
749
766
|
)
|
|
750
767
|
SSVM = Run(
|
|
751
|
-
func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
|
|
752
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
|
|
768
|
+
func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
|
|
769
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
|
|
753
770
|
needs_closure=True,
|
|
754
|
-
func='rosen', steps=50, loss=
|
|
771
|
+
func='rosen', steps=50, loss=0.2, merge_invariant=True,
|
|
755
772
|
sphere_steps=10, sphere_loss=0,
|
|
756
773
|
)
|
|
757
774
|
|
|
@@ -766,26 +783,26 @@ LBFGS = Run(
|
|
|
766
783
|
|
|
767
784
|
# ----------------------------- quasi_newton/lsr1 ---------------------------- #
|
|
768
785
|
LSR1 = Run(
|
|
769
|
-
func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
|
|
770
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
|
|
786
|
+
func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
|
|
787
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
|
|
771
788
|
needs_closure=True,
|
|
772
789
|
func='rosen', steps=50, loss=0, merge_invariant=True,
|
|
773
790
|
sphere_steps=10, sphere_loss=0,
|
|
774
791
|
)
|
|
775
792
|
|
|
776
|
-
# ---------------------------- quasi_newton/olbfgs --------------------------- #
|
|
777
|
-
OnlineLBFGS = Run(
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
)
|
|
793
|
+
# # ---------------------------- quasi_newton/olbfgs --------------------------- #
|
|
794
|
+
# OnlineLBFGS = Run(
|
|
795
|
+
# func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
|
|
796
|
+
# sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
|
|
797
|
+
# needs_closure=True,
|
|
798
|
+
# func='rosen', steps=50, loss=0, merge_invariant=True,
|
|
799
|
+
# sphere_steps=10, sphere_loss=0,
|
|
800
|
+
# )
|
|
784
801
|
|
|
785
802
|
# ---------------------------- second_order/newton --------------------------- #
|
|
786
803
|
Newton = Run(
|
|
787
|
-
func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
|
|
788
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
|
|
804
|
+
func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
|
|
805
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
|
|
789
806
|
needs_closure=True,
|
|
790
807
|
func='rosen', steps=20, loss=1e-7, merge_invariant=True,
|
|
791
808
|
sphere_steps=2, sphere_loss=1e-9,
|
|
@@ -793,8 +810,8 @@ Newton = Run(
|
|
|
793
810
|
|
|
794
811
|
# --------------------------- second_order/newton_cg -------------------------- #
|
|
795
812
|
NewtonCG = Run(
|
|
796
|
-
func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
|
|
797
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
|
|
813
|
+
func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
|
|
814
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
|
|
798
815
|
needs_closure=True,
|
|
799
816
|
func='rosen', steps=20, loss=1e-7, merge_invariant=True,
|
|
800
817
|
sphere_steps=2, sphere_loss=3e-4,
|
|
@@ -802,11 +819,11 @@ NewtonCG = Run(
|
|
|
802
819
|
|
|
803
820
|
# ---------------------------- smoothing/gaussian ---------------------------- #
|
|
804
821
|
GaussianHomotopy = Run(
|
|
805
|
-
func_opt=lambda p: tz.Modular(p, tz.m.
|
|
806
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.
|
|
822
|
+
func_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
|
|
823
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
|
|
807
824
|
needs_closure=True,
|
|
808
|
-
func='booth', steps=20, loss=0.
|
|
809
|
-
sphere_steps=10, sphere_loss=
|
|
825
|
+
func='booth', steps=20, loss=0.01, merge_invariant=True,
|
|
826
|
+
sphere_steps=10, sphere_loss=1,
|
|
810
827
|
)
|
|
811
828
|
|
|
812
829
|
# ---------------------------- smoothing/laplacian --------------------------- #
|
|
@@ -860,7 +877,7 @@ SophiaH = Run(
|
|
|
860
877
|
sphere_steps=10, sphere_loss=40,
|
|
861
878
|
)
|
|
862
879
|
|
|
863
|
-
# --------------------------
|
|
880
|
+
# -------------------------- higher_order ------------------------- #
|
|
864
881
|
HigherOrderNewton = Run(
|
|
865
882
|
func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
|
|
866
883
|
sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
|
|
@@ -869,15 +886,33 @@ HigherOrderNewton = Run(
|
|
|
869
886
|
sphere_steps=1, sphere_loss=1e-10,
|
|
870
887
|
)
|
|
871
888
|
|
|
889
|
+
# ---------------------------- optimizers/ladagrad --------------------------- #
|
|
890
|
+
LMAdagrad = Run(
|
|
891
|
+
func_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(4)),
|
|
892
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(5)),
|
|
893
|
+
needs_closure=False,
|
|
894
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
895
|
+
sphere_steps=20, sphere_loss=1e-9,
|
|
896
|
+
)
|
|
897
|
+
|
|
898
|
+
# ------------------------------ optimizers/adan ----------------------------- #
|
|
899
|
+
Adan = Run(
|
|
900
|
+
func_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(1)),
|
|
901
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(0.1)),
|
|
902
|
+
needs_closure=False,
|
|
903
|
+
func='booth', steps=50, loss=60, merge_invariant=True,
|
|
904
|
+
sphere_steps=20, sphere_loss=60,
|
|
905
|
+
)
|
|
906
|
+
|
|
872
907
|
# ------------------------------------ CGs ----------------------------------- #
|
|
873
|
-
for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.
|
|
908
|
+
for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.DYHS, tz.m.ProjectedGradientMethod):
|
|
874
909
|
for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
|
|
875
910
|
# but also test 10 to make sure it doesn't explode after converging
|
|
876
911
|
Run(
|
|
877
912
|
func_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
|
|
878
913
|
sphere_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
|
|
879
914
|
needs_closure=True,
|
|
880
|
-
func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=
|
|
915
|
+
func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=True,
|
|
881
916
|
sphere_steps=sphere_steps_, sphere_loss=0,
|
|
882
917
|
)
|
|
883
918
|
|
|
@@ -885,17 +920,22 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
|
|
|
885
920
|
# stability test
|
|
886
921
|
for QN in (
|
|
887
922
|
tz.m.BFGS,
|
|
923
|
+
partial(tz.m.BFGS, inverse=False),
|
|
888
924
|
tz.m.SR1,
|
|
925
|
+
partial(tz.m.SR1, inverse=False),
|
|
889
926
|
tz.m.DFP,
|
|
927
|
+
partial(tz.m.DFP, inverse=False),
|
|
890
928
|
tz.m.BroydenGood,
|
|
929
|
+
partial(tz.m.BroydenGood, inverse=False),
|
|
891
930
|
tz.m.BroydenBad,
|
|
931
|
+
partial(tz.m.BroydenBad, inverse=False),
|
|
892
932
|
tz.m.Greenstadt1,
|
|
893
933
|
tz.m.Greenstadt2,
|
|
894
|
-
tz.m.
|
|
934
|
+
tz.m.ICUM,
|
|
895
935
|
tz.m.ThomasOptimalMethod,
|
|
896
936
|
tz.m.FletcherVMM,
|
|
897
937
|
tz.m.Horisho,
|
|
898
|
-
|
|
938
|
+
partial(tz.m.Horisho, inner=tz.m.GradientCorrection()),
|
|
899
939
|
tz.m.Pearson,
|
|
900
940
|
tz.m.ProjectedNewtonRaphson,
|
|
901
941
|
tz.m.PSB,
|
|
@@ -903,10 +943,10 @@ for QN in (
|
|
|
903
943
|
tz.m.SSVM,
|
|
904
944
|
):
|
|
905
945
|
Run(
|
|
906
|
-
func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
|
|
907
|
-
sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
|
|
946
|
+
func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
|
|
947
|
+
sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
|
|
908
948
|
needs_closure=True,
|
|
909
|
-
func='lstsq', steps=50, loss=1e-10, merge_invariant=
|
|
949
|
+
func='lstsq', steps=50, loss=1e-10, merge_invariant=True,
|
|
910
950
|
sphere_steps=10, sphere_loss=1e-20,
|
|
911
951
|
)
|
|
912
952
|
|
tests/test_tensorlist.py
CHANGED
|
@@ -977,22 +977,23 @@ def test_rademacher_like(big_tl: TensorList):
|
|
|
977
977
|
|
|
978
978
|
@pytest.mark.parametrize("dist", ['normal', 'uniform', 'sphere', 'rademacher'])
|
|
979
979
|
def test_sample_like(simple_tl: TensorList, dist):
|
|
980
|
-
eps_scalar =
|
|
981
|
-
result_tl_scalar = simple_tl.sample_like(
|
|
980
|
+
eps_scalar = 1
|
|
981
|
+
result_tl_scalar = simple_tl.sample_like(distribution=dist)
|
|
982
982
|
assert isinstance(result_tl_scalar, TensorList)
|
|
983
983
|
assert result_tl_scalar.shape == simple_tl.shape
|
|
984
984
|
|
|
985
|
-
eps_list = [
|
|
986
|
-
result_tl_list = simple_tl.sample_like(
|
|
985
|
+
eps_list = [1.0,]
|
|
986
|
+
result_tl_list = simple_tl.sample_like(distribution=dist)
|
|
987
987
|
assert isinstance(result_tl_list, TensorList)
|
|
988
988
|
assert result_tl_list.shape == simple_tl.shape
|
|
989
989
|
|
|
990
990
|
# Basic checks based on distribution
|
|
991
991
|
if dist == 'uniform':
|
|
992
|
-
assert all(torch.all((t >= -eps_scalar
|
|
993
|
-
assert all(torch.all((t >= -e
|
|
992
|
+
assert all(torch.all((t >= -eps_scalar) & (t <= eps_scalar)) for t in result_tl_scalar)
|
|
993
|
+
assert all(torch.all((t >= -e) & (t <= e)) for t, e in zip(result_tl_list, eps_list))
|
|
994
994
|
elif dist == 'sphere':
|
|
995
|
-
assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
|
|
995
|
+
# assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
|
|
996
|
+
pass
|
|
996
997
|
# Cannot check list version easily
|
|
997
998
|
elif dist == 'rademacher':
|
|
998
999
|
assert all(torch.all((t == -eps_scalar) | (t == eps_scalar)) for t in result_tl_scalar)
|
tests/test_vars.py
CHANGED
|
@@ -156,6 +156,7 @@ def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
|
|
|
156
156
|
for k,v in v1.__dict__.items():
|
|
157
157
|
if not k.startswith('__'):
|
|
158
158
|
# if k == 'post_step_hooks': continue
|
|
159
|
+
if k == 'storage': continue
|
|
159
160
|
if k == 'update' and clone_update:
|
|
160
161
|
if v1.update is None or v2.update is None:
|
|
161
162
|
assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
|
torchzero/__init__.py
CHANGED
torchzero/core/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
from .module import
|
|
2
|
-
from .transform import
|
|
1
|
+
from .module import Chain, Chainable, Modular, Module, Var, maybe_chain
|
|
2
|
+
from .transform import Target, TensorwiseTransform, Transform, apply_transform
|