torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
tests/test_opts.py
CHANGED

@@ -1,4 +1,9 @@
-"""
+"""
+Sanity tests to make sure everything works.
+
+This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
+don't error or become unhinged with different parameter shapes.
+"""
 from collections.abc import Callable
 from functools import partial
 
@@ -68,6 +73,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
         assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
         losses.append(loss)
 
+    losses.append(objective())
     return torch.stack(losses).nan_to_num(0,10000,10000).min()
 
 def _run_func(opt_fn: Callable, func:str, merge: bool, use_closure: bool, steps: int):
@@ -524,7 +530,7 @@ PolyakStepSize = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
     sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
     needs_closure=True,
-    func='booth', steps=50, loss=1e-
+    func='booth', steps=50, loss=1e-7, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.002,
 )
 RandomStepSize = Run(
@@ -604,44 +610,44 @@ ScaleModulesByCosineSimilarity = Run(
 
 # ------------------------- momentum/matrix_momentum ------------------------- #
 MatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
    sphere_steps=10, sphere_loss=0,
 )
 MatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
    needs_closure=True,
    func='booth', steps=50, loss=0.05, merge_invariant=True,
    sphere_steps=10, sphere_loss=0,
 )
 MatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
    needs_closure=True,
    func='booth', steps=50, loss=0.05, merge_invariant=True,
    sphere_steps=10, sphere_loss=0,
 )
 
 AdaptiveMatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
    needs_closure=True,
    func='booth', steps=50, loss=0.002, merge_invariant=True,
    sphere_steps=10, sphere_loss=0,
 )
 AdaptiveMatrixMomentum_central = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
    needs_closure=True,
    func='booth', steps=50, loss=0.002, merge_invariant=True,
    sphere_steps=10, sphere_loss=0,
 )
 AdaptiveMatrixMomentum_autograd = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
    needs_closure=True,
    func='booth', steps=50, loss=0.002, merge_invariant=True,
    sphere_steps=10, sphere_loss=0,
@@ -719,11 +725,11 @@ Lion = Run(
 )
 # ---------------------------- optimizers/shampoo ---------------------------- #
 Shampoo = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.
+    func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
    needs_closure=False,
-    func='booth', steps=50, loss=
-    sphere_steps=20, sphere_loss=
+    func='booth', steps=50, loss=0.02, merge_invariant=False,
+    sphere_steps=20, sphere_loss=1, # merge and unmerge lrs are very different so need to test convergence separately somewhere
 )
 
 # ------------------------- quasi_newton/quasi_newton ------------------------ #
@@ -791,7 +797,7 @@ NewtonCG = Run(
     sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
     needs_closure=True,
     func='rosen', steps=20, loss=1e-7, merge_invariant=True,
-    sphere_steps=2, sphere_loss=
+    sphere_steps=2, sphere_loss=3e-4,
 )
 
 # ---------------------------- smoothing/gaussian ---------------------------- #
@@ -854,8 +860,17 @@ SophiaH = Run(
     sphere_steps=10, sphere_loss=40,
 )
 
+# -------------------------- optimizers/higher_order ------------------------- #
+HigherOrderNewton = Run(
+    func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
+    needs_closure=True,
+    func='rosen', steps=1, loss=2e-10, merge_invariant=True,
+    sphere_steps=1, sphere_loss=1e-10,
+)
+
 # ------------------------------------ CGs ----------------------------------- #
-for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY):
+for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
     for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
         # but also test 10 to make sure it doesn't explode after converging
         Run(
@@ -868,7 +883,25 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
 
 # ------------------------------- QN stability ------------------------------- #
 # stability test
-for QN in (
+for QN in (
+    tz.m.BFGS,
+    tz.m.SR1,
+    tz.m.DFP,
+    tz.m.BroydenGood,
+    tz.m.BroydenBad,
+    tz.m.Greenstadt1,
+    tz.m.Greenstadt2,
+    tz.m.ColumnUpdatingMethod,
+    tz.m.ThomasOptimalMethod,
+    tz.m.FletcherVMM,
+    tz.m.Horisho,
+    lambda scale_first: tz.m.Horisho(scale_first=scale_first, inner=tz.m.GradientCorrection()),
+    tz.m.Pearson,
+    tz.m.ProjectedNewtonRaphson,
+    tz.m.PSB,
+    tz.m.McCormick,
+    tz.m.SSVM,
+):
     Run(
         func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
        sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
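The hunks above pin down the new test configurations for MatrixMomentum and AdaptiveMatrixMomentum: the Hessian-vector-product method ('forward', 'central', 'autograd') is now passed explicitly and the step size is grafted on with tz.m.LR. A minimal sketch of driving the same configuration outside the test harness follows; the backward-flag closure convention is assumed from the pattern used in tests/test_vars.py below, and the quadratic stands in for the named 'booth' objective.

# Illustrative sketch, not part of the package diff.
import torch
import torchzero as tz

x = torch.tensor([0.0, 0.0], requires_grad=True)
opt = tz.Modular([x], tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01))

def closure(backward=True):
    # Booth function, the 'booth' objective referenced by the Run entries above.
    loss = (x[0] + 2*x[1] - 7)**2 + (2*x[0] + x[1] - 5)**2
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(50):
    opt.step(closure)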
tests/test_tensorlist.py
CHANGED

@@ -1261,8 +1261,8 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
         elif reduction_method == 'quantile': expected = vec.quantile(q)
         else:
             pytest.fail("Unknown global reduction")
-            assert False,
-        assert torch.allclose(result, expected)
+            assert False, reduction_method
+        assert torch.allclose(result, expected, atol=1e-4)
     else:
         expected_list = []
         for t in simple_tl:
tests/test_vars.py
CHANGED

@@ -1,10 +1,10 @@
 import pytest
 import torch
-from torchzero.core.module import
+from torchzero.core.module import Var
 from torchzero.utils.tensorlist import TensorList
 
 @torch.no_grad
-def
+def test_var_get_loss():
 
     # ---------------------------- test that it works ---------------------------- #
     params = [torch.tensor(2.0, requires_grad=True)]
@@ -26,20 +26,20 @@ def test_vars_get_loss():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss
 
-
+    var = Var(params=params, closure=closure_1, model=None, current_step=0)
 
-    assert
+    assert var.loss is None, var.loss
 
-    assert (loss :=
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
     assert evaluated, evaluated
-    assert loss is
-    assert
-    assert
-    assert
+    assert loss is var.loss
+    assert var.loss == 4.0
+    assert var.loss_approx == 4.0
+    assert var.grad is None, var.grad
 
     # reevaluate, which should just return already evaluated loss
-    assert (loss :=
-    assert
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert var.grad is None, var.grad
 
 
     # ----------------------- test that backward=True works ---------------------- #
@@ -61,30 +61,30 @@ def test_vars_get_loss():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss
 
-
-    assert
-    assert (loss :=
-    assert
-    assert
+    var = Var(params=params, closure=closure_2, model=None, current_step=0)
+    assert var.grad is None, var.grad
+    assert (loss := var.get_loss(backward=True)) == 6.0, loss
+    assert var.grad is not None
+    assert var.grad[0] == 2.0, var.grad
 
     # reevaluate, which should just return already evaluated loss
-    assert (loss :=
-    assert
+    assert (loss := var.get_loss(backward=True)) == 6.0, loss
+    assert var.grad[0] == 2.0, var.grad
 
     # get grad, which should just return already evaluated grad
-    assert (grad :=
-    assert grad is
+    assert (grad := var.get_grad())[0] == 2.0, grad
+    assert grad is var.grad, grad
 
     # get update, which should create and return cloned grad
-    assert
-    assert (update :=
-    assert update is
-    assert update is not
-    assert
-    assert update[0] ==
+    assert var.update is None
+    assert (update := var.get_update())[0] == 2.0, update
+    assert update is var.update
+    assert update is not var.grad
+    assert var.grad is not None
+    assert update[0] == var.grad[0]
 
 @torch.no_grad
-def
+def test_var_get_grad():
     params = [torch.tensor(2.0, requires_grad=True)]
     evaluated = False
 
@@ -103,20 +103,20 @@ def test_vars_get_grad():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss
 
-
-    assert (grad :=
-    assert grad is
+    var = Var(params=params, closure=closure, model=None, current_step=0)
+    assert (grad := var.get_grad())[0] == 4.0, grad
+    assert grad is var.grad
 
-    assert
-    assert (loss :=
-    assert (loss :=
-    assert
+    assert var.loss == 4.0
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert (loss := var.get_loss(backward=True)) == 4.0, loss
+    assert var.loss_approx == 4.0
 
-    assert
-    assert (update :=
+    assert var.update is None, var.update
+    assert (update := var.get_update())[0] == 4.0, update
 
 @torch.no_grad
-def
+def test_var_get_update():
     params = [torch.tensor(2.0, requires_grad=True)]
     evaluated = False
 
@@ -135,24 +135,24 @@ def test_vars_get_update():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss
 
-
-    assert
-    assert (update :=
-    assert update is
+    var = Var(params=params, closure=closure, model=None, current_step=0)
+    assert var.update is None, var.update
+    assert (update := var.get_update())[0] == 4.0, update
+    assert update is var.update
 
-    assert (grad :=
-    assert grad is
+    assert (grad := var.get_grad())[0] == 4.0, grad
+    assert grad is var.grad
     assert grad is not update
 
-    assert
-    assert (loss :=
-    assert (loss :=
-    assert
+    assert var.loss == 4.0
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert (loss := var.get_loss(backward=True)) == 4.0, loss
+    assert var.loss_approx == 4.0
 
-    assert (update :=
+    assert (update := var.get_update())[0] == 4.0, update
 
 
-def
+def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
     for k,v in v1.__dict__.items():
         if not k.startswith('__'):
             # if k == 'post_step_hooks': continue
@@ -165,20 +165,20 @@ def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
         else:
             assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'
 
-def
+def test_var_clone():
     model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
     def closure(backward): return 1
-
+    var = Var(params=list(model.parameters()), closure=closure, model=model, current_step=0)
 
-
-
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
 
-
-
-
+    var.grad = TensorList(torch.randn(5))
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
 
-
-
-
-
-
+    var.update = TensorList(torch.randn(5) * 2)
+    var.loss = torch.randn(1)
+    var.loss_approx = var.loss
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
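Taken together, the renamed tests describe the Var lifecycle fairly completely. A condensed sketch of that lifecycle, assuming the same backward-flag closure convention the tests use and the x = 2.0, loss = x**2 setup from test_var_get_grad, could look like this:

# Illustrative sketch, not part of the package diff.
import torch
from torchzero.core.module import Var

x = torch.tensor(2.0, requires_grad=True)

def closure(backward=True):
    loss = x ** 2                 # 4.0 at x = 2.0
    if backward:
        x.grad = None
        loss.backward()           # d(x**2)/dx = 2x = 4.0
    return loss

var = Var(params=[x], closure=closure, model=None, current_step=0)
assert var.loss is None                   # nothing evaluated yet
loss = var.get_loss(backward=True)        # evaluates the closure, sets var.loss and var.grad
grad = var.get_grad()                     # returns the cached gradient list
update = var.get_update()                 # lazily clones the gradient into var.update
assert update is var.update and update is not var.grad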
torchzero/core/__init__.py
CHANGED

@@ -1,3 +1,2 @@
-from .module import
-from .transform import Transform, TensorwiseTransform, Target,
-from .preconditioner import Preconditioner, TensorwisePreconditioner
+from .module import Var, Module, Modular, Chain, maybe_chain, Chainable
+from .transform import Transform, TensorwiseTransform, Target, apply_transform
torchzero/core/module.py
CHANGED

@@ -29,8 +29,8 @@ def _closure_backward(closure, params, retain_graph, create_graph):
     return loss
 
 # region Vars
-# -----------------------------------
-class
+# ----------------------------------- var ----------------------------------- #
+class Var:
     """
     Holds the state and context passed between optimizer modules during a step.
 
@@ -74,13 +74,13 @@ class Vars:
         """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
         whereas some other modules require loss strictly at current point."""
 
-        self.post_step_hooks: list[Callable[[Modular,
+        self.post_step_hooks: list[Callable[[Modular, Var]]] = []
         """list of functions to be called after optimizer step.
         The signature is:
 
         .. code:: py
 
-            def hook(optimizer: Modular,
+            def hook(optimizer: Modular, var: Vars): ...
 
         """
 
@@ -110,7 +110,7 @@ class Vars:
         """if True, the parameters will not be updated"""
 
     def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
-        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`
+        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
         Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
 
         if self.loss is None:
@@ -143,7 +143,7 @@ class Vars:
 
     def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
         """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
-        :code:`
+        :code:`var.grad` and potentially :code:`var.loss`. Do not call this at perturbed parameters."""
         if self.grad is None:
             if self.closure is None: raise RuntimeError("closure is None")
             self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
@@ -152,15 +152,15 @@ class Vars:
         return self.grad
 
     def get_update(self) -> list[torch.Tensor]:
-        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`
-        Computing the gradients may assign :code:`
+        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`var.update`.
+        Computing the gradients may assign :code:`var.grad` and :code:`var.loss` if they haven't been computed.
         Do not call this at perturbed parameters."""
         if self.update is None: self.update = [g.clone() for g in self.get_grad()]
         return self.update
 
     def clone(self, clone_update: bool):
         """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
-        copy =
+        copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
 
         if clone_update and self.update is not None:
             copy.update = [u.clone() for u in self.update]
@@ -176,16 +176,16 @@ class Vars:
 
         return copy
 
-    def update_attrs_from_clone_(self,
+    def update_attrs_from_clone_(self, var: "Var"):
         """Updates attributes of this `Vars` instance from a cloned instance.
         Typically called after a child module has processed a cloned `Vars`
         object. This propagates any newly computed loss or gradient values
         from the child's context back to the parent `Vars` if the parent
         didn't have them computed already.
         """
-        if self.loss is None: self.loss =
-        if self.loss_approx is None: self.loss_approx =
-        if self.grad is None: self.grad =
+        if self.loss is None: self.loss = var.loss
+        if self.loss_approx is None: self.loss_approx = var.loss_approx
+        if self.grad is None: self.grad = var.grad
 
     def zero_grad(self, set_to_none=True):
         if set_to_none:
@@ -269,36 +269,36 @@ class Module(ABC):
         return s
 
     @overload
-    def get_settings(self, key: str, *,
-
+    def get_settings(self, params: Sequence[torch.Tensor], key: str, *,
+                     cls: type[ListLike] = list) -> ListLike: ...
     @overload
-    def get_settings(self, key: list[str] | tuple[str,...], *,
-
+    def get_settings(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
+                     cls: type[ListLike] = list) -> list[ListLike]: ...
     @overload
-    def get_settings(self, key: str, key2: str, *keys: str,
-
+    def get_settings(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
+                     cls: type[ListLike] = list) -> list[ListLike]: ...
 
-    def get_settings(self, key: str | list[str] | tuple[str,...], key2: str | None = None,
-
+    def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
+                     *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
         # if isinstance(params, Vars): params = params.params
         return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
 
 
     @overload
-    def get_state(self, key: str, *,
-
+    def get_state(self, params: Sequence[torch.Tensor], key: str, *,
+                  must_exist: bool = False, init: Init = torch.zeros_like,
                   cls: type[ListLike] = list) -> ListLike: ...
     @overload
-    def get_state(self, key: list[str] | tuple[str,...], *,
-
+    def get_state(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> list[ListLike]: ...
     @overload
-    def get_state(self, key: str, key2: str, *keys: str,
-
+    def get_state(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> list[ListLike]: ...
 
-    def get_state(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
-
+    def get_state(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> ListLike | list[ListLike]:
         """Returns values of per-parameter state for a given key.
         If key doesn't exist, create it with inits.
@@ -404,8 +404,8 @@ class Module(ABC):
 
     # ---------------------------- OVERRIDABLE METHODS --------------------------- #
     @abstractmethod
-    def step(self,
-        """performs a step, returns new
+    def step(self, var: Var) -> Var:
+        """performs a step, returns new var but may update them in-place."""
 
     def reset(self):
         """Resets the internal state of the module (e.g. momentum)."""
@@ -556,13 +556,13 @@ class Modular(torch.optim.Optimizer):
             if not p.requires_grad: continue
             for map in self._per_parameter_global_settings[p]: map.update(settings)
 
-        # create
+        # create var
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-
+        var = Var(params=params, closure=closure, model=self.model, current_step=self.current_step)
 
         # if closure is None, assume backward has been called and gather grads
         if closure is None:
-
+            var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
         last_module = self.modules[-1]
         last_lr = last_module.defaults.get('lr', None)
@@ -570,27 +570,27 @@ class Modular(torch.optim.Optimizer):
 
         # step
         for i, module in enumerate(self.modules):
-            if i!=0:
+            if i!=0: var = var.clone(clone_update=False)
 
             # last module, or next to last module before lr
             if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
-                if module.children:
-                else:
-                if last_lr is not None:
+                if module.children: var.nested_is_last = True
+                else: var.is_last = True
+                if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]
 
-
-            if
+            var = module.step(var)
+            if var.stop: break
 
         # apply update
-        if not
+        if not var.skip_update:
             with torch.no_grad():
-                torch._foreach_sub_(params,
+                torch._foreach_sub_(params, var.get_update())
 
-        for hook in
-            hook(self,
+        for hook in var.post_step_hooks:
+            hook(self, var)
 
         self.current_step += 1
-        return
+        return var.loss if var.loss is not None else var.loss_approx
 
     def __repr__(self):
         return f'Modular({", ".join(str(m) for m in self.modules)})'
@@ -606,11 +606,11 @@ class Chain(Module):
         for i, module in enumerate(flat_modules):
             self.set_child(f'module_{i}', module)
 
-    def step(self,
+    def step(self, var):
         for i in range(len(self.children)):
-
-            if
-        return
+            var = self.children[f'module_{i}'].step(var)
+            if var.stop: break
+        return var
 
     def __repr__(self):
         s = self.__class__.__name__