torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
tests/test_opts.py
CHANGED
|
@@ -56,14 +56,17 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
|
|
|
56
56
|
if use_closure:
|
|
57
57
|
def closure(backward=True):
|
|
58
58
|
loss = objective()
|
|
59
|
+
losses.append(loss.detach())
|
|
59
60
|
if backward:
|
|
60
61
|
opt.zero_grad()
|
|
61
62
|
loss.backward()
|
|
62
63
|
return loss
|
|
63
|
-
|
|
64
|
-
assert
|
|
65
|
-
|
|
66
|
-
|
|
64
|
+
ret = opt.step(closure)
|
|
65
|
+
assert ret is not None # the return should be the loss
|
|
66
|
+
with torch.no_grad():
|
|
67
|
+
loss = objective() # in case f(x_0) is not evaluated
|
|
68
|
+
assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
|
|
69
|
+
losses.append(loss.detach())
|
|
67
70
|
|
|
68
71
|
else:
|
|
69
72
|
loss = objective()
|
|
@@ -71,7 +74,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
|
|
|
71
74
|
loss.backward()
|
|
72
75
|
opt.step()
|
|
73
76
|
assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
|
|
74
|
-
losses.append(loss)
|
|
77
|
+
losses.append(loss.detach())
|
|
75
78
|
|
|
76
79
|
losses.append(objective())
|
|
77
80
|
return torch.stack(losses).nan_to_num(0,10000,10000).min()
|
|
@@ -374,6 +377,21 @@ RandomizedFDM_central4 = Run(
|
|
|
374
377
|
func='booth', steps=50, loss=10, merge_invariant=True,
|
|
375
378
|
sphere_steps=100, sphere_loss=450,
|
|
376
379
|
)
|
|
380
|
+
RandomizedFDM_forward4 = Run(
|
|
381
|
+
func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
|
|
382
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
|
|
383
|
+
needs_closure=True,
|
|
384
|
+
func='booth', steps=50, loss=10, merge_invariant=True,
|
|
385
|
+
sphere_steps=100, sphere_loss=450,
|
|
386
|
+
)
|
|
387
|
+
RandomizedFDM_forward5 = Run(
|
|
388
|
+
func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
|
|
389
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
|
|
390
|
+
needs_closure=True,
|
|
391
|
+
func='booth', steps=50, loss=10, merge_invariant=True,
|
|
392
|
+
sphere_steps=100, sphere_loss=450,
|
|
393
|
+
)
|
|
394
|
+
|
|
377
395
|
|
|
378
396
|
RandomizedFDM_4samples = Run(
|
|
379
397
|
func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
|
|
@@ -382,13 +400,6 @@ RandomizedFDM_4samples = Run(
|
|
|
382
400
|
func='booth', steps=50, loss=1e-5, merge_invariant=True,
|
|
383
401
|
sphere_steps=100, sphere_loss=400,
|
|
384
402
|
)
|
|
385
|
-
RandomizedFDM_4samples_lerp = Run(
|
|
386
|
-
func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.99, seed=0), tz.m.LR(0.1)),
|
|
387
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.9, seed=0), tz.m.LR(0.001)),
|
|
388
|
-
needs_closure=True,
|
|
389
|
-
func='booth', steps=50, loss=1e-5, merge_invariant=True,
|
|
390
|
-
sphere_steps=100, sphere_loss=505,
|
|
391
|
-
)
|
|
392
403
|
RandomizedFDM_4samples_no_pre_generate = Run(
|
|
393
404
|
func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
|
|
394
405
|
sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
|
|
@@ -455,25 +466,11 @@ Backtracking = Run(
|
|
|
455
466
|
func='booth', steps=50, loss=0, merge_invariant=True,
|
|
456
467
|
sphere_steps=2, sphere_loss=0,
|
|
457
468
|
)
|
|
458
|
-
Backtracking_try_negative = Run(
|
|
459
|
-
func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
|
|
460
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
|
|
461
|
-
needs_closure=True,
|
|
462
|
-
func='booth', steps=50, loss=1e-9, merge_invariant=True,
|
|
463
|
-
sphere_steps=2, sphere_loss=1e-10,
|
|
464
|
-
)
|
|
465
469
|
AdaptiveBacktracking = Run(
|
|
466
470
|
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
|
|
467
471
|
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
|
|
468
472
|
needs_closure=True,
|
|
469
|
-
func='booth', steps=50, loss=1e-
|
|
470
|
-
sphere_steps=2, sphere_loss=1e-10,
|
|
471
|
-
)
|
|
472
|
-
AdaptiveBacktracking_try_negative = Run(
|
|
473
|
-
func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
|
|
474
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
|
|
475
|
-
needs_closure=True,
|
|
476
|
-
func='booth', steps=50, loss=1e-8, merge_invariant=True,
|
|
473
|
+
func='booth', steps=50, loss=1e-11, merge_invariant=True,
|
|
477
474
|
sphere_steps=2, sphere_loss=1e-10,
|
|
478
475
|
)
|
|
479
476
|
# ----------------------------- line_search/scipy ---------------------------- #
|
|
@@ -578,8 +575,8 @@ UpdateGradientSignConsistency = Run(
|
|
|
578
575
|
sphere_steps=10, sphere_loss=2,
|
|
579
576
|
)
|
|
580
577
|
IntermoduleCautious = Run(
|
|
581
|
-
func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(
|
|
582
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(
|
|
578
|
+
func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
|
|
579
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
|
|
583
580
|
needs_closure=False,
|
|
584
581
|
func='booth', steps=50, loss=1e-4, merge_invariant=True,
|
|
585
582
|
sphere_steps=10, sphere_loss=0.1,
|
|
@@ -592,8 +589,8 @@ ScaleByGradCosineSimilarity = Run(
|
|
|
592
589
|
sphere_steps=10, sphere_loss=0.1,
|
|
593
590
|
)
|
|
594
591
|
ScaleModulesByCosineSimilarity = Run(
|
|
595
|
-
func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(
|
|
596
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(
|
|
592
|
+
func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
|
|
593
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
|
|
597
594
|
needs_closure=False,
|
|
598
595
|
func='booth', steps=50, loss=0.005, merge_invariant=True,
|
|
599
596
|
sphere_steps=10, sphere_loss=0.1,
|
|
@@ -601,47 +598,69 @@ ScaleModulesByCosineSimilarity = Run(
|
|
|
601
598
|
|
|
602
599
|
# ------------------------- momentum/matrix_momentum ------------------------- #
|
|
603
600
|
MatrixMomentum_forward = Run(
|
|
604
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'),
|
|
605
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward')
|
|
601
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='forward'),),
|
|
602
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward')),
|
|
606
603
|
needs_closure=True,
|
|
607
604
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
608
|
-
sphere_steps=10, sphere_loss=0,
|
|
605
|
+
sphere_steps=10, sphere_loss=0.01,
|
|
609
606
|
)
|
|
610
607
|
MatrixMomentum_forward = Run(
|
|
611
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central')
|
|
612
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central')
|
|
608
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='central')),
|
|
609
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central')),
|
|
613
610
|
needs_closure=True,
|
|
614
611
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
615
|
-
sphere_steps=10, sphere_loss=0,
|
|
612
|
+
sphere_steps=10, sphere_loss=0.01,
|
|
616
613
|
)
|
|
617
614
|
MatrixMomentum_forward = Run(
|
|
618
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd')
|
|
619
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd')
|
|
615
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
|
|
616
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
|
|
620
617
|
needs_closure=True,
|
|
621
618
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
622
|
-
sphere_steps=10, sphere_loss=0,
|
|
619
|
+
sphere_steps=10, sphere_loss=0.01,
|
|
623
620
|
)
|
|
624
621
|
|
|
625
622
|
AdaptiveMatrixMomentum_forward = Run(
|
|
626
|
-
func_opt=lambda p: tz.Modular(p, tz.m.
|
|
627
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.
|
|
623
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True)),
|
|
624
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True)),
|
|
628
625
|
needs_closure=True,
|
|
629
|
-
func='booth', steps=50, loss=0.
|
|
630
|
-
sphere_steps=10, sphere_loss=0,
|
|
626
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
627
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
631
628
|
)
|
|
632
629
|
AdaptiveMatrixMomentum_central = Run(
|
|
633
|
-
func_opt=lambda p: tz.Modular(p, tz.m.
|
|
634
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.
|
|
630
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True)),
|
|
631
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True)),
|
|
635
632
|
needs_closure=True,
|
|
636
|
-
func='booth', steps=50, loss=0.
|
|
637
|
-
sphere_steps=10, sphere_loss=0,
|
|
633
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
634
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
638
635
|
)
|
|
639
636
|
AdaptiveMatrixMomentum_autograd = Run(
|
|
640
|
-
func_opt=lambda p: tz.Modular(p, tz.m.
|
|
641
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.
|
|
637
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
|
|
638
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
|
|
642
639
|
needs_closure=True,
|
|
643
|
-
func='booth', steps=50, loss=0.
|
|
644
|
-
sphere_steps=10, sphere_loss=0,
|
|
640
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
641
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
StochasticAdaptiveMatrixMomentum_forward = Run(
|
|
645
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True, adapt_freq=1)),
|
|
646
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True, adapt_freq=1)),
|
|
647
|
+
needs_closure=True,
|
|
648
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
649
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
650
|
+
)
|
|
651
|
+
StochasticAdaptiveMatrixMomentum_central = Run(
|
|
652
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True, adapt_freq=1)),
|
|
653
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True, adapt_freq=1)),
|
|
654
|
+
needs_closure=True,
|
|
655
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
656
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
657
|
+
)
|
|
658
|
+
StochasticAdaptiveMatrixMomentum_autograd = Run(
|
|
659
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
|
|
660
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
|
|
661
|
+
needs_closure=True,
|
|
662
|
+
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
663
|
+
sphere_steps=10, sphere_loss=0.05,
|
|
645
664
|
)
|
|
646
665
|
|
|
647
666
|
# EMA, momentum are covered by test_identical
|
|
@@ -668,8 +687,8 @@ UpdateSign = Run(
|
|
|
668
687
|
sphere_steps=10, sphere_loss=0,
|
|
669
688
|
)
|
|
670
689
|
GradAccumulation = Run(
|
|
671
|
-
func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.05)
|
|
672
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.5)
|
|
690
|
+
func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
|
|
691
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
|
|
673
692
|
needs_closure=False,
|
|
674
693
|
func='booth', steps=50, loss=25, merge_invariant=True,
|
|
675
694
|
sphere_steps=20, sphere_loss=1e-11,
|
|
@@ -725,24 +744,24 @@ Shampoo = Run(
|
|
|
725
744
|
|
|
726
745
|
# ------------------------- quasi_newton/quasi_newton ------------------------ #
|
|
727
746
|
BFGS = Run(
|
|
728
|
-
func_opt=lambda p: tz.Modular(p, tz.m.BFGS(
|
|
729
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(
|
|
747
|
+
func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
|
|
748
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
|
|
730
749
|
needs_closure=True,
|
|
731
750
|
func='rosen', steps=50, loss=1e-10, merge_invariant=True,
|
|
732
751
|
sphere_steps=10, sphere_loss=1e-10,
|
|
733
752
|
)
|
|
734
753
|
SR1 = Run(
|
|
735
|
-
func_opt=lambda p: tz.Modular(p, tz.m.SR1(
|
|
736
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(
|
|
754
|
+
func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_restart=True, scale_first=True), tz.m.StrongWolfe(fallback=False)),
|
|
755
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
|
|
737
756
|
needs_closure=True,
|
|
738
757
|
func='rosen', steps=50, loss=1e-12, merge_invariant=True,
|
|
739
758
|
sphere_steps=10, sphere_loss=0,
|
|
740
759
|
)
|
|
741
760
|
SSVM = Run(
|
|
742
|
-
func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1,
|
|
743
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1,
|
|
761
|
+
func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
|
|
762
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
|
|
744
763
|
needs_closure=True,
|
|
745
|
-
func='rosen', steps=50, loss=0.
|
|
764
|
+
func='rosen', steps=50, loss=0.2, merge_invariant=True,
|
|
746
765
|
sphere_steps=10, sphere_loss=0,
|
|
747
766
|
)
|
|
748
767
|
|
|
@@ -757,8 +776,8 @@ LBFGS = Run(
|
|
|
757
776
|
|
|
758
777
|
# ----------------------------- quasi_newton/lsr1 ---------------------------- #
|
|
759
778
|
LSR1 = Run(
|
|
760
|
-
func_opt=lambda p: tz.Modular(p, tz.m.LSR1(
|
|
761
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(
|
|
779
|
+
func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
|
|
780
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
|
|
762
781
|
needs_closure=True,
|
|
763
782
|
func='rosen', steps=50, loss=0, merge_invariant=True,
|
|
764
783
|
sphere_steps=10, sphere_loss=0,
|
|
@@ -775,8 +794,8 @@ LSR1 = Run(
|
|
|
775
794
|
|
|
776
795
|
# ---------------------------- second_order/newton --------------------------- #
|
|
777
796
|
Newton = Run(
|
|
778
|
-
func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
|
|
779
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
|
|
797
|
+
func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
|
|
798
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
|
|
780
799
|
needs_closure=True,
|
|
781
800
|
func='rosen', steps=20, loss=1e-7, merge_invariant=True,
|
|
782
801
|
sphere_steps=2, sphere_loss=1e-9,
|
|
@@ -784,8 +803,8 @@ Newton = Run(
|
|
|
784
803
|
|
|
785
804
|
# --------------------------- second_order/newton_cg -------------------------- #
|
|
786
805
|
NewtonCG = Run(
|
|
787
|
-
func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
|
|
788
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
|
|
806
|
+
func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
|
|
807
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
|
|
789
808
|
needs_closure=True,
|
|
790
809
|
func='rosen', steps=20, loss=1e-7, merge_invariant=True,
|
|
791
810
|
sphere_steps=2, sphere_loss=3e-4,
|
|
@@ -793,11 +812,11 @@ NewtonCG = Run(
|
|
|
793
812
|
|
|
794
813
|
# ---------------------------- smoothing/gaussian ---------------------------- #
|
|
795
814
|
GaussianHomotopy = Run(
|
|
796
|
-
func_opt=lambda p: tz.Modular(p, tz.m.
|
|
797
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.
|
|
815
|
+
func_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
|
|
816
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
|
|
798
817
|
needs_closure=True,
|
|
799
|
-
func='booth', steps=20, loss=0.
|
|
800
|
-
sphere_steps=10, sphere_loss=
|
|
818
|
+
func='booth', steps=20, loss=0.01, merge_invariant=True,
|
|
819
|
+
sphere_steps=10, sphere_loss=1,
|
|
801
820
|
)
|
|
802
821
|
|
|
803
822
|
# ---------------------------- smoothing/laplacian --------------------------- #
|
|
@@ -879,14 +898,14 @@ Adan = Run(
|
|
|
879
898
|
)
|
|
880
899
|
|
|
881
900
|
# ------------------------------------ CGs ----------------------------------- #
|
|
882
|
-
for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.
|
|
901
|
+
for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.DYHS, tz.m.ProjectedGradientMethod):
|
|
883
902
|
for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
|
|
884
903
|
# but also test 10 to make sure it doesn't explode after converging
|
|
885
904
|
Run(
|
|
886
905
|
func_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
|
|
887
906
|
sphere_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
|
|
888
907
|
needs_closure=True,
|
|
889
|
-
func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=
|
|
908
|
+
func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=True,
|
|
890
909
|
sphere_steps=sphere_steps_, sphere_loss=0,
|
|
891
910
|
)
|
|
892
911
|
|
|
@@ -917,10 +936,10 @@ for QN in (
|
|
|
917
936
|
tz.m.SSVM,
|
|
918
937
|
):
|
|
919
938
|
Run(
|
|
920
|
-
func_opt=lambda p: tz.Modular(p, QN(scale_first=False,
|
|
921
|
-
sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False,
|
|
939
|
+
func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
|
|
940
|
+
sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
|
|
922
941
|
needs_closure=True,
|
|
923
|
-
func='lstsq', steps=50, loss=1e-10, merge_invariant=
|
|
942
|
+
func='lstsq', steps=50, loss=1e-10, merge_invariant=True,
|
|
924
943
|
sphere_steps=10, sphere_loss=1e-20,
|
|
925
944
|
)
|
|
926
945
|
|
tests/test_tensorlist.py
CHANGED
|
@@ -977,22 +977,23 @@ def test_rademacher_like(big_tl: TensorList):
|
|
|
977
977
|
|
|
978
978
|
@pytest.mark.parametrize("dist", ['normal', 'uniform', 'sphere', 'rademacher'])
|
|
979
979
|
def test_sample_like(simple_tl: TensorList, dist):
|
|
980
|
-
eps_scalar =
|
|
981
|
-
result_tl_scalar = simple_tl.sample_like(
|
|
980
|
+
eps_scalar = 1
|
|
981
|
+
result_tl_scalar = simple_tl.sample_like(distribution=dist)
|
|
982
982
|
assert isinstance(result_tl_scalar, TensorList)
|
|
983
983
|
assert result_tl_scalar.shape == simple_tl.shape
|
|
984
984
|
|
|
985
|
-
eps_list = [
|
|
986
|
-
result_tl_list = simple_tl.sample_like(
|
|
985
|
+
eps_list = [1.0,]
|
|
986
|
+
result_tl_list = simple_tl.sample_like(distribution=dist)
|
|
987
987
|
assert isinstance(result_tl_list, TensorList)
|
|
988
988
|
assert result_tl_list.shape == simple_tl.shape
|
|
989
989
|
|
|
990
990
|
# Basic checks based on distribution
|
|
991
991
|
if dist == 'uniform':
|
|
992
|
-
assert all(torch.all((t >= -eps_scalar
|
|
993
|
-
assert all(torch.all((t >= -e
|
|
992
|
+
assert all(torch.all((t >= -eps_scalar) & (t <= eps_scalar)) for t in result_tl_scalar)
|
|
993
|
+
assert all(torch.all((t >= -e) & (t <= e)) for t, e in zip(result_tl_list, eps_list))
|
|
994
994
|
elif dist == 'sphere':
|
|
995
|
-
assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
|
|
995
|
+
# assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
|
|
996
|
+
pass
|
|
996
997
|
# Cannot check list version easily
|
|
997
998
|
elif dist == 'rademacher':
|
|
998
999
|
assert all(torch.all((t == -eps_scalar) | (t == eps_scalar)) for t in result_tl_scalar)
|
torchzero/__init__.py
CHANGED
torchzero/core/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
from .module import
|
|
2
|
-
from .transform import
|
|
1
|
+
from .module import Chain, Chainable, Modular, Module, Var, maybe_chain
|
|
2
|
+
from .transform import Target, TensorwiseTransform, Transform, apply_transform
|