torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
tests/test_opts.py CHANGED
@@ -1,4 +1,9 @@
- """snity tests to make sure everything works and converges on basic functions"""
+ """
+ Sanity tests to make sure everything works.
+
+ This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
+ don't error or become unhinged with different parameter shapes.
+ """
  from collections.abc import Callable
  from functools import partial
 
@@ -68,6 +73,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
  assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
  losses.append(loss)
 
+ losses.append(objective())
  return torch.stack(losses).nan_to_num(0,10000,10000).min()
 
  def _run_func(opt_fn: Callable, func:str, merge: bool, use_closure: bool, steps: int):
@@ -524,7 +530,7 @@ PolyakStepSize = Run(
  func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
  sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
  needs_closure=True,
- func='booth', steps=50, loss=1e-11, merge_invariant=True,
+ func='booth', steps=50, loss=1e-7, merge_invariant=True,
  sphere_steps=10, sphere_loss=0.002,
  )
  RandomStepSize = Run(
@@ -604,44 +610,44 @@ ScaleModulesByCosineSimilarity = Run(
 
  # ------------------------- momentum/matrix_momentum ------------------------- #
  MatrixMomentum_forward = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.01)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
  needs_closure=True,
  func='booth', steps=50, loss=0.05, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )
  MatrixMomentum_forward = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='central'), tz.m.LR(0.01)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='central'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
  needs_closure=True,
  func='booth', steps=50, loss=0.05, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )
  MatrixMomentum_forward = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.01)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
  needs_closure=True,
  func='booth', steps=50, loss=0.05, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )
 
  AdaptiveMatrixMomentum_forward = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='forward'), tz.m.LR(0.05)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='forward'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
  needs_closure=True,
  func='booth', steps=50, loss=0.002, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )
  AdaptiveMatrixMomentum_central = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='central'), tz.m.LR(0.05)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='central'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
  needs_closure=True,
  func='booth', steps=50, loss=0.002, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )
  AdaptiveMatrixMomentum_autograd = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.05)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
  needs_closure=True,
  func='booth', steps=50, loss=0.002, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
@@ -719,11 +725,11 @@ Lion = Run(
  )
  # ---------------------------- optimizers/shampoo ---------------------------- #
  Shampoo = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.2)),
+ func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
  needs_closure=False,
- func='booth', steps=50, loss=200, merge_invariant=False,
- sphere_steps=20, sphere_loss=1e-3, # merge and unmerge lrs are very different so need to test convergence separately somewhere
+ func='booth', steps=50, loss=0.02, merge_invariant=False,
+ sphere_steps=20, sphere_loss=1, # merge and unmerge lrs are very different so need to test convergence separately somewhere
  )
 
  # ------------------------- quasi_newton/quasi_newton ------------------------ #
@@ -745,7 +751,7 @@ SSVM = Run(
  func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
  sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
  needs_closure=True,
- func='rosen', steps=50, loss=0.02, merge_invariant=True,
+ func='rosen', steps=50, loss=1e-10, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )
 
@@ -791,7 +797,7 @@ NewtonCG = Run(
  sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
  needs_closure=True,
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
- sphere_steps=2, sphere_loss=1e-6,
+ sphere_steps=2, sphere_loss=3e-4,
  )
 
  # ---------------------------- smoothing/gaussian ---------------------------- #
@@ -854,8 +860,17 @@ SophiaH = Run(
  sphere_steps=10, sphere_loss=40,
  )
 
+ # -------------------------- optimizers/higher_order ------------------------- #
+ HigherOrderNewton = Run(
+ func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
+ needs_closure=True,
+ func='rosen', steps=1, loss=2e-10, merge_invariant=True,
+ sphere_steps=1, sphere_loss=1e-10,
+ )
+
  # ------------------------------------ CGs ----------------------------------- #
- for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY):
+ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
  for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
  # but also test 10 to make sure it doesn't explode after converging
  Run(
@@ -868,7 +883,25 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
 
  # ------------------------------- QN stability ------------------------------- #
  # stability test
- for QN in (tz.m.BFGS, tz.m.SR1, tz.m.DFP, tz.m.BroydenGood, tz.m.BroydenBad, tz.m.Greenstadt1, tz.m.Greenstadt2, tz.m.ColumnUpdatingMethod, tz.m.ThomasOptimalMethod, tz.m.PSB, tz.m.Pearson2, tz.m.SSVM):
+ for QN in (
+ tz.m.BFGS,
+ tz.m.SR1,
+ tz.m.DFP,
+ tz.m.BroydenGood,
+ tz.m.BroydenBad,
+ tz.m.Greenstadt1,
+ tz.m.Greenstadt2,
+ tz.m.ColumnUpdatingMethod,
+ tz.m.ThomasOptimalMethod,
+ tz.m.FletcherVMM,
+ tz.m.Horisho,
+ lambda scale_first: tz.m.Horisho(scale_first=scale_first, inner=tz.m.GradientCorrection()),
+ tz.m.Pearson,
+ tz.m.ProjectedNewtonRaphson,
+ tz.m.PSB,
+ tz.m.McCormick,
+ tz.m.SSVM,
+ ):
  Run(
  func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
  sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
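The test changes above track an API rename: MatrixMomentum and AdaptiveMatrixMomentum now take `hvp_method` instead of `hvp_mode`. A minimal usage sketch of the module chain these tests build (the quadratic objective and the closure wiring below are illustrative, not taken from the test file):

```python
# Hedged sketch: same chain as the MatrixMomentum runs above, with the
# renamed hvp_method keyword. The objective here is a toy quadratic.
import torch
import torchzero as tz

params = [torch.randn(10, requires_grad=True)]
opt = tz.Modular(params, tz.m.MatrixMomentum(hvp_method="central"), tz.m.LR(0.01))

def closure(backward=True):
    loss = (params[0] ** 2).sum()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(50):
    opt.step(closure)  # a closure is required; these runs set needs_closure=True
```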
tests/test_tensorlist.py CHANGED
@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho
  expected = vec_equiv_func()
 
  if isinstance(result, bool): assert result == expected
- else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
+ else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"
 
 
  def test_global_vector_norm(simple_tl: TensorList):
@@ -1261,8 +1261,8 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
  elif reduction_method == 'quantile': expected = vec.quantile(q)
  else:
  pytest.fail("Unknown global reduction")
- assert False, 'sus'
- assert torch.allclose(result, expected)
+ assert False, reduction_method
+ assert torch.allclose(result, expected, atol=1e-4)
  else:
  expected_list = []
  for t in simple_tl:
tests/test_vars.py CHANGED
@@ -1,10 +1,10 @@
  import pytest
  import torch
- from torchzero.core.module import Vars
+ from torchzero.core.module import Var
  from torchzero.utils.tensorlist import TensorList
 
  @torch.no_grad
- def test_vars_get_loss():
+ def test_var_get_loss():
 
  # ---------------------------- test that it works ---------------------------- #
  params = [torch.tensor(2.0, requires_grad=True)]
@@ -26,20 +26,20 @@ def test_vars_get_loss():
  assert not loss.requires_grad, "loss requires grad with backward=False"
  return loss
 
- vars = Vars(params=params, closure=closure_1, model=None, current_step=0)
+ var = Var(params=params, closure=closure_1, model=None, current_step=0)
 
- assert vars.loss is None, vars.loss
+ assert var.loss is None, var.loss
 
- assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+ assert (loss := var.get_loss(backward=False)) == 4.0, loss
  assert evaluated, evaluated
- assert loss is vars.loss
- assert vars.loss == 4.0
- assert vars.loss_approx == 4.0
- assert vars.grad is None, vars.grad
+ assert loss is var.loss
+ assert var.loss == 4.0
+ assert var.loss_approx == 4.0
+ assert var.grad is None, var.grad
 
  # reevaluate, which should just return already evaluated loss
- assert (loss := vars.get_loss(backward=False)) == 4.0, loss
- assert vars.grad is None, vars.grad
+ assert (loss := var.get_loss(backward=False)) == 4.0, loss
+ assert var.grad is None, var.grad
 
 
  # ----------------------- test that backward=True works ---------------------- #
@@ -61,30 +61,30 @@ def test_vars_get_loss():
  assert not loss.requires_grad, "loss requires grad with backward=False"
  return loss
 
- vars = Vars(params=params, closure=closure_2, model=None, current_step=0)
- assert vars.grad is None, vars.grad
- assert (loss := vars.get_loss(backward=True)) == 6.0, loss
- assert vars.grad is not None
- assert vars.grad[0] == 2.0, vars.grad
+ var = Var(params=params, closure=closure_2, model=None, current_step=0)
+ assert var.grad is None, var.grad
+ assert (loss := var.get_loss(backward=True)) == 6.0, loss
+ assert var.grad is not None
+ assert var.grad[0] == 2.0, var.grad
 
  # reevaluate, which should just return already evaluated loss
- assert (loss := vars.get_loss(backward=True)) == 6.0, loss
- assert vars.grad[0] == 2.0, vars.grad
+ assert (loss := var.get_loss(backward=True)) == 6.0, loss
+ assert var.grad[0] == 2.0, var.grad
 
  # get grad, which should just return already evaluated grad
- assert (grad := vars.get_grad())[0] == 2.0, grad
- assert grad is vars.grad, grad
+ assert (grad := var.get_grad())[0] == 2.0, grad
+ assert grad is var.grad, grad
 
  # get update, which should create and return cloned grad
- assert vars.update is None
- assert (update := vars.get_update())[0] == 2.0, update
- assert update is vars.update
- assert update is not vars.grad
- assert vars.grad is not None
- assert update[0] == vars.grad[0]
+ assert var.update is None
+ assert (update := var.get_update())[0] == 2.0, update
+ assert update is var.update
+ assert update is not var.grad
+ assert var.grad is not None
+ assert update[0] == var.grad[0]
 
  @torch.no_grad
- def test_vars_get_grad():
+ def test_var_get_grad():
  params = [torch.tensor(2.0, requires_grad=True)]
  evaluated = False
 
@@ -103,20 +103,20 @@ def test_vars_get_grad():
  assert not loss.requires_grad, "loss requires grad with backward=False"
  return loss
 
- vars = Vars(params=params, closure=closure, model=None, current_step=0)
- assert (grad := vars.get_grad())[0] == 4.0, grad
- assert grad is vars.grad
+ var = Var(params=params, closure=closure, model=None, current_step=0)
+ assert (grad := var.get_grad())[0] == 4.0, grad
+ assert grad is var.grad
 
- assert vars.loss == 4.0
- assert (loss := vars.get_loss(backward=False)) == 4.0, loss
- assert (loss := vars.get_loss(backward=True)) == 4.0, loss
- assert vars.loss_approx == 4.0
+ assert var.loss == 4.0
+ assert (loss := var.get_loss(backward=False)) == 4.0, loss
+ assert (loss := var.get_loss(backward=True)) == 4.0, loss
+ assert var.loss_approx == 4.0
 
- assert vars.update is None, vars.update
- assert (update := vars.get_update())[0] == 4.0, update
+ assert var.update is None, var.update
+ assert (update := var.get_update())[0] == 4.0, update
 
  @torch.no_grad
- def test_vars_get_update():
+ def test_var_get_update():
  params = [torch.tensor(2.0, requires_grad=True)]
  evaluated = False
 
@@ -135,24 +135,24 @@ def test_vars_get_update():
  assert not loss.requires_grad, "loss requires grad with backward=False"
  return loss
 
- vars = Vars(params=params, closure=closure, model=None, current_step=0)
- assert vars.update is None, vars.update
- assert (update := vars.get_update())[0] == 4.0, update
- assert update is vars.update
+ var = Var(params=params, closure=closure, model=None, current_step=0)
+ assert var.update is None, var.update
+ assert (update := var.get_update())[0] == 4.0, update
+ assert update is var.update
 
- assert (grad := vars.get_grad())[0] == 4.0, grad
- assert grad is vars.grad
+ assert (grad := var.get_grad())[0] == 4.0, grad
+ assert grad is var.grad
  assert grad is not update
 
- assert vars.loss == 4.0
- assert (loss := vars.get_loss(backward=False)) == 4.0, loss
- assert (loss := vars.get_loss(backward=True)) == 4.0, loss
- assert vars.loss_approx == 4.0
+ assert var.loss == 4.0
+ assert (loss := var.get_loss(backward=False)) == 4.0, loss
+ assert (loss := var.get_loss(backward=True)) == 4.0, loss
+ assert var.loss_approx == 4.0
 
- assert (update := vars.get_update())[0] == 4.0, update
+ assert (update := var.get_update())[0] == 4.0, update
 
 
- def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
+ def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
  for k,v in v1.__dict__.items():
  if not k.startswith('__'):
  # if k == 'post_step_hooks': continue
@@ -165,20 +165,20 @@ def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
  else:
  assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'
 
- def test_vars_clone():
+ def test_var_clone():
  model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
  def closure(backward): return 1
- vars = Vars(params=list(model.parameters()), closure=closure, model=model, current_step=0)
+ var = Var(params=list(model.parameters()), closure=closure, model=model, current_step=0)
 
- _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
- _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+ _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+ _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
 
- vars.grad = TensorList(torch.randn(5))
- _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
- _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+ var.grad = TensorList(torch.randn(5))
+ _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+ _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
 
- vars.update = TensorList(torch.randn(5) * 2)
- vars.loss = torch.randn(1)
- vars.loss_approx = vars.loss
- _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
- _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+ var.update = TensorList(torch.randn(5) * 2)
+ var.loss = torch.randn(1)
+ var.loss_approx = var.loss
+ _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+ _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
torchzero/core/__init__.py CHANGED
@@ -1,3 +1,2 @@
- from .module import Vars, Module, Modular, Chain, maybe_chain, Chainable
- from .transform import Transform, TensorwiseTransform, Target, apply
- from .preconditioner import Preconditioner, TensorwisePreconditioner
+ from .module import Var, Module, Modular, Chain, maybe_chain, Chainable
+ from .transform import Transform, TensorwiseTransform, Target, apply_transform
torchzero/core/module.py CHANGED
@@ -29,8 +29,8 @@ def _closure_backward(closure, params, retain_graph, create_graph):
  return loss
 
  # region Vars
- # ----------------------------------- vars ----------------------------------- #
- class Vars:
+ # ----------------------------------- var ----------------------------------- #
+ class Var:
  """
  Holds the state and context passed between optimizer modules during a step.
 
@@ -74,13 +74,13 @@ class Vars:
  """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
  whereas some other modules require loss strictly at current point."""
 
- self.post_step_hooks: list[Callable[[Modular, Vars]]] = []
+ self.post_step_hooks: list[Callable[[Modular, Var]]] = []
  """list of functions to be called after optimizer step.
  The signature is:
 
  .. code:: py
 
- def hook(optimizer: Modular, vars: Vars): ...
+ def hook(optimizer: Modular, var: Vars): ...
 
  """
 
@@ -110,7 +110,7 @@ class Vars:
  """if True, the parameters will not be updated"""
 
  def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
- """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`vars.loss`.
+ """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
  Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
 
  if self.loss is None:
@@ -143,7 +143,7 @@ class Vars:
 
  def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
  """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
- :code:`vars.grad` and potentially :code:`vars.loss`. Do not call this at perturbed parameters."""
+ :code:`var.grad` and potentially :code:`var.loss`. Do not call this at perturbed parameters."""
  if self.grad is None:
  if self.closure is None: raise RuntimeError("closure is None")
  self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
@@ -152,15 +152,15 @@ class Vars:
  return self.grad
 
  def get_update(self) -> list[torch.Tensor]:
- """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`vars.update`.
- Computing the gradients may assign :code:`vars.grad` and :code:`vars.loss` if they haven't been computed.
+ """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`var.update`.
+ Computing the gradients may assign :code:`var.grad` and :code:`var.loss` if they haven't been computed.
  Do not call this at perturbed parameters."""
  if self.update is None: self.update = [g.clone() for g in self.get_grad()]
  return self.update
 
  def clone(self, clone_update: bool):
  """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
- copy = Vars(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
+ copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
 
  if clone_update and self.update is not None:
  copy.update = [u.clone() for u in self.update]
@@ -176,16 +176,16 @@ class Vars:
 
  return copy
 
- def update_attrs_from_clone_(self, vars: "Vars"):
+ def update_attrs_from_clone_(self, var: "Var"):
  """Updates attributes of this `Vars` instance from a cloned instance.
  Typically called after a child module has processed a cloned `Vars`
  object. This propagates any newly computed loss or gradient values
  from the child's context back to the parent `Vars` if the parent
  didn't have them computed already.
  """
- if self.loss is None: self.loss = vars.loss
- if self.loss_approx is None: self.loss_approx = vars.loss_approx
- if self.grad is None: self.grad = vars.grad
+ if self.loss is None: self.loss = var.loss
+ if self.loss_approx is None: self.loss_approx = var.loss_approx
+ if self.grad is None: self.grad = var.grad
 
  def zero_grad(self, set_to_none=True):
  if set_to_none:
@@ -269,36 +269,36 @@ class Module(ABC):
  return s
 
  @overload
- def get_settings(self, key: str, *,
- params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> ListLike: ...
+ def get_settings(self, params: Sequence[torch.Tensor], key: str, *,
+ cls: type[ListLike] = list) -> ListLike: ...
  @overload
- def get_settings(self, key: list[str] | tuple[str,...], *,
- params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> list[ListLike]: ...
+ def get_settings(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
+ cls: type[ListLike] = list) -> list[ListLike]: ...
  @overload
- def get_settings(self, key: str, key2: str, *keys: str,
- params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> list[ListLike]: ...
+ def get_settings(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
+ cls: type[ListLike] = list) -> list[ListLike]: ...
 
- def get_settings(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
- params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+ def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
+ *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
  # if isinstance(params, Vars): params = params.params
  return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
 
 
  @overload
- def get_state(self, key: str, *,
- params: Sequence[torch.Tensor], must_exist: bool = False, init: Init = torch.zeros_like,
+ def get_state(self, params: Sequence[torch.Tensor], key: str, *,
+ must_exist: bool = False, init: Init = torch.zeros_like,
  cls: type[ListLike] = list) -> ListLike: ...
  @overload
- def get_state(self, key: list[str] | tuple[str,...], *,
- params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+ def get_state(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
+ must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
  cls: type[ListLike] = list) -> list[ListLike]: ...
  @overload
- def get_state(self, key: str, key2: str, *keys: str,
- params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+ def get_state(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
+ must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
  cls: type[ListLike] = list) -> list[ListLike]: ...
 
- def get_state(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
- params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+ def get_state(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+ must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
  cls: type[ListLike] = list) -> ListLike | list[ListLike]:
  """Returns values of per-parameter state for a given key.
  If key doesn't exist, create it with inits.
@@ -404,8 +404,8 @@ class Module(ABC):
 
  # ---------------------------- OVERRIDABLE METHODS --------------------------- #
  @abstractmethod
- def step(self, vars: Vars) -> Vars:
- """performs a step, returns new vars but may update them in-place."""
+ def step(self, var: Var) -> Var:
+ """performs a step, returns new var but may update them in-place."""
 
  def reset(self):
  """Resets the internal state of the module (e.g. momentum)."""
@@ -556,13 +556,13 @@ class Modular(torch.optim.Optimizer):
  if not p.requires_grad: continue
  for map in self._per_parameter_global_settings[p]: map.update(settings)
 
- # create vars
+ # create var
  params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
- vars = Vars(params=params, closure=closure, model=self.model, current_step=self.current_step)
+ var = Var(params=params, closure=closure, model=self.model, current_step=self.current_step)
 
  # if closure is None, assume backward has been called and gather grads
  if closure is None:
- vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+ var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
  last_module = self.modules[-1]
  last_lr = last_module.defaults.get('lr', None)
@@ -570,27 +570,27 @@ class Modular(torch.optim.Optimizer):
 
  # step
  for i, module in enumerate(self.modules):
- if i!=0: vars = vars.clone(clone_update=False)
+ if i!=0: var = var.clone(clone_update=False)
 
  # last module, or next to last module before lr
  if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
- if module.children: vars.nested_is_last = True
- else: vars.is_last = True
- if last_lr is not None: vars.last_module_lrs = last_module.get_settings('lr', params=vars.params)
+ if module.children: var.nested_is_last = True
+ else: var.is_last = True
+ if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]
 
- vars = module.step(vars)
- if vars.stop: break
+ var = module.step(var)
+ if var.stop: break
 
  # apply update
- if not vars.skip_update:
+ if not var.skip_update:
  with torch.no_grad():
- torch._foreach_sub_(params, vars.get_update())
+ torch._foreach_sub_(params, var.get_update())
 
- for hook in vars.post_step_hooks:
- hook(self, vars)
+ for hook in var.post_step_hooks:
+ hook(self, var)
 
  self.current_step += 1
- return vars.loss if vars.loss is not None else vars.loss_approx
+ return var.loss if var.loss is not None else var.loss_approx
 
  def __repr__(self):
  return f'Modular({", ".join(str(m) for m in self.modules)})'
@@ -606,11 +606,11 @@ class Chain(Module):
  for i, module in enumerate(flat_modules):
  self.set_child(f'module_{i}', module)
 
- def step(self, vars):
+ def step(self, var):
  for i in range(len(self.children)):
- vars = self.children[f'module_{i}'].step(vars)
- if vars.stop: break
- return vars
+ var = self.children[f'module_{i}'].step(var)
+ if var.stop: break
+ return var
 
  def __repr__(self):
  s = self.__class__.__name__
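The module.py hunks above rename `Vars` to `Var` and move `params` to the first positional argument of `get_settings` and `get_state`. A hedged sketch of a custom module written against that 0.3.10 shape (the `HeavyBall` class and its `velocity` key are made up for illustration, and `Module.__init__` is assumed to accept a defaults dict):

```python
# Illustrative only; not part of the wheel.
import torch
from torchzero.core import Module, Var

class HeavyBall(Module):
    def __init__(self, momentum: float = 0.9):
        # assumption: Module stores per-parameter defaults passed as a dict
        super().__init__(dict(momentum=momentum))

    def step(self, var: Var) -> Var:
        update = var.get_update()  # clones the gradients on first access
        # params-first call order per the 0.3.10 signatures shown above;
        # get_state initializes missing buffers with torch.zeros_like
        buffers = self.get_state(var.params, "velocity")
        momentum = self.get_settings(var.params, "momentum")
        for u, buf, m in zip(update, buffers, momentum):
            buf.mul_(m).add_(u)   # velocity = m * velocity + update
            u.copy_(buf)          # write the new direction back into var.update
        return var  # Modular later does torch._foreach_sub_(params, var.get_update())
```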