torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
docs/source/conf.py CHANGED
@@ -6,10 +6,10 @@
 # -- Project information -----------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
 import sys, os
-#sys.path.insert(0, os.path.abspath('.../src'))
+#sys.path.insert(0, os.path.abspath('.../src'))
 
 project = 'torchzero'
-copyright = '2024, Ivan Nikishev'
+copyright = '2025, Ivan Nikishev'
 author = 'Ivan Nikishev'
 
 # -- General configuration ---------------------------------------------------
@@ -24,10 +24,12 @@ extensions = [
     'sphinx.ext.githubpages',
     'sphinx.ext.napoleon',
     'autoapi.extension',
+    "myst_nb",
+
     # 'sphinx_rtd_theme',
 ]
 autosummary_generate = True
-autoapi_dirs = ['../../src']
+autoapi_dirs = ['../../torchzero']
 autoapi_type = "python"
 # autoapi_ignore = ["*/tensorlist.py"]
 
@@ -48,7 +50,7 @@ exclude_patterns = []
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
 
 #html_theme = 'alabaster'
-html_theme = 'furo'
+html_theme = 'sphinx_rtd_theme'
 html_static_path = ['_static']
 
 
docs/source/docstring template.py ADDED
@@ -0,0 +1,46 @@
+class MyModule:
+    """[One-line summary of the class].
+
+    [A more detailed description of the class, explaining its purpose, how it
+    works, and its typical use cases. You can use multiple paragraphs.]
+
+    .. note::
+        [Optional: Add important notes, warnings, or usage guidelines here.
+        For example, you could mention if a closure is required, discuss
+        stability, or highlight performance characteristics. Use the `.. note::`
+        directive to make it stand out in the documentation.]
+
+    Args:
+        param1 (type, optional):
+            [Description of the first parameter. Use :code:`backticks` for
+            inline code like variable names or specific values like ``"autograd"``.
+            Explain what the parameter does.] Defaults to [value].
+        param2 (type):
+            [Description of a mandatory parameter (no "optional" or "Defaults to").]
+        **kwargs:
+            [If you accept keyword arguments, describe what they are used for.]
+
+    Examples:
+        [A title or short sentence describing the first example]:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                ...
+            )
+
+        [A title or short sentence for a second, different example]:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                ...
+            )
+
+    References:
+        - [Optional: A citation for a relevant paper, book, or algorithm.]
+        - [Optional: A link to a blog post or website with more information.]
+
+    """
tests/test_identical.py CHANGED
@@ -96,8 +96,7 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
 
 @pytest.mark.parametrize('amsgrad', [True, False])
 def test_adam(amsgrad):
-    # torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
-    # pytorch applies debiasing separately so it is applied before epsilo
+    torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
     tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
     tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
     tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
@@ -133,7 +132,7 @@ def test_adam(amsgrad):
         tz.m.Debias2(beta=0.999),
         tz.m.Add(1e-8)]
     ))
-    tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
+    tz_fns = (torch_fn, tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
 
     _assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
     for fn in tz_fns:
tests/test_opts.py CHANGED
@@ -292,42 +292,42 @@ FDM_central2 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=1e-7, merge_invariant=True,
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_forward2 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=1e-7, merge_invariant=True,
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_backward2 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=2e-7, merge_invariant=True,
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_forward3 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=3e-7, merge_invariant=True,
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_backward3 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=3e-7, merge_invariant=True,
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_central4 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=2e-8, merge_invariant=True,
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 
@@ -466,8 +466,8 @@ AdaptiveBacktracking = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
     sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
     needs_closure=True,
-    func='booth', steps=50, loss=0, merge_invariant=True,
-    sphere_steps=2, sphere_loss=0,
+    func='booth', steps=50, loss=1e-12, merge_invariant=True,
+    sphere_steps=2, sphere_loss=1e-10,
 )
 AdaptiveBacktracking_try_negative = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
@@ -494,15 +494,6 @@ StrongWolfe = Run(
     sphere_steps=2, sphere_loss=0,
 )
 
-# ------------------------- line_search/trust_region ------------------------- #
-TrustRegion = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.TrustRegion()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.TrustRegion(init=0.1)),
-    needs_closure=True,
-    func='booth', steps=50, loss=0.1, merge_invariant=True,
-    sphere_steps=10, sphere_loss=1e-5,
-)
-
 # ----------------------------------- lr/lr ---------------------------------- #
 LR = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
@@ -587,8 +578,8 @@ UpdateGradientSignConsistency = Run(
     sphere_steps=10, sphere_loss=2,
 )
 IntermoduleCautious = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.01)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.1)),
+    func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.01)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.1)),
     needs_closure=False,
     func='booth', steps=50, loss=1e-4, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.1,
@@ -601,8 +592,8 @@ ScaleByGradCosineSimilarity = Run(
     sphere_steps=10, sphere_loss=0.1,
 )
 ScaleModulesByCosineSimilarity = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.05)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.1)),
+    func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.05)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.1)),
     needs_closure=False,
     func='booth', steps=50, loss=0.005, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.1,
@@ -684,8 +675,8 @@ GradAccumulation = Run(
     sphere_steps=20, sphere_loss=1e-11,
 )
 NegateOnLossIncrease = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(),),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(),),
+    func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
     needs_closure=True,
     func='booth', steps=50, loss=0.1, merge_invariant=True,
     sphere_steps=20, sphere_loss=0.001,
@@ -693,7 +684,7 @@ NegateOnLossIncrease = Run(
 # -------------------------------- misc/switch ------------------------------- #
 Alternate = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
     needs_closure=False,
     func='booth', steps=50, loss=1, merge_invariant=True,
     sphere_steps=20, sphere_loss=20,
@@ -734,24 +725,24 @@ Shampoo = Run(
 
 # ------------------------- quasi_newton/quasi_newton ------------------------ #
 BFGS = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
     needs_closure=True,
-    func='rosen', steps=50, loss=0, merge_invariant=True,
-    sphere_steps=10, sphere_loss=0,
+    func='rosen', steps=50, loss=1e-10, merge_invariant=True,
+    sphere_steps=10, sphere_loss=1e-10,
 )
 SR1 = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
     needs_closure=True,
     func='rosen', steps=50, loss=1e-12, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 SSVM = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
     needs_closure=True,
-    func='rosen', steps=50, loss=1e-10, merge_invariant=True,
+    func='rosen', steps=50, loss=0.5, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 
@@ -766,21 +757,21 @@ LBFGS = Run(
 
 # ----------------------------- quasi_newton/lsr1 ---------------------------- #
 LSR1 = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
     needs_closure=True,
     func='rosen', steps=50, loss=0, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 
-# ---------------------------- quasi_newton/olbfgs --------------------------- #
-OnlineLBFGS = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
-    needs_closure=True,
-    func='rosen', steps=50, loss=0, merge_invariant=True,
-    sphere_steps=10, sphere_loss=0,
-)
+# # ---------------------------- quasi_newton/olbfgs --------------------------- #
+# OnlineLBFGS = Run(
+#     func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
+#     sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
+#     needs_closure=True,
+#     func='rosen', steps=50, loss=0, merge_invariant=True,
+#     sphere_steps=10, sphere_loss=0,
+# )
 
 # ---------------------------- second_order/newton --------------------------- #
 Newton = Run(
@@ -802,8 +793,8 @@ NewtonCG = Run(
 
 # ---------------------------- smoothing/gaussian ---------------------------- #
 GaussianHomotopy = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
     needs_closure=True,
     func='booth', steps=20, loss=0.1, merge_invariant=True,
     sphere_steps=10, sphere_loss=200,
@@ -860,7 +851,7 @@ SophiaH = Run(
     sphere_steps=10, sphere_loss=40,
 )
 
-# -------------------------- optimizers/higher_order ------------------------- #
+# -------------------------- higher_order ------------------------- #
 HigherOrderNewton = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
@@ -869,6 +860,24 @@ HigherOrderNewton = Run(
     sphere_steps=1, sphere_loss=1e-10,
 )
 
+# ---------------------------- optimizers/ladagrad --------------------------- #
+LMAdagrad = Run(
+    func_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(4)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(5)),
+    needs_closure=False,
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
+    sphere_steps=20, sphere_loss=1e-9,
+)
+
+# ------------------------------ optimizers/adan ----------------------------- #
+Adan = Run(
+    func_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(1)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(0.1)),
+    needs_closure=False,
+    func='booth', steps=50, loss=60, merge_invariant=True,
+    sphere_steps=20, sphere_loss=60,
+)
+
 # ------------------------------------ CGs ----------------------------------- #
 for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
     for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
@@ -885,17 +894,22 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
 # stability test
 for QN in (
     tz.m.BFGS,
+    partial(tz.m.BFGS, inverse=False),
     tz.m.SR1,
+    partial(tz.m.SR1, inverse=False),
     tz.m.DFP,
+    partial(tz.m.DFP, inverse=False),
     tz.m.BroydenGood,
+    partial(tz.m.BroydenGood, inverse=False),
     tz.m.BroydenBad,
+    partial(tz.m.BroydenBad, inverse=False),
     tz.m.Greenstadt1,
     tz.m.Greenstadt2,
-    tz.m.ColumnUpdatingMethod,
+    tz.m.ICUM,
     tz.m.ThomasOptimalMethod,
     tz.m.FletcherVMM,
     tz.m.Horisho,
-    lambda scale_first: tz.m.Horisho(scale_first=scale_first, inner=tz.m.GradientCorrection()),
+    partial(tz.m.Horisho, inner=tz.m.GradientCorrection()),
     tz.m.Pearson,
     tz.m.ProjectedNewtonRaphson,
     tz.m.PSB,
@@ -903,8 +917,8 @@ for QN in (
     tz.m.SSVM,
 ):
     Run(
-        func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
-        sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
+        func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
+        sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
         needs_closure=True,
         func='lstsq', steps=50, loss=1e-10, merge_invariant=False,
         sphere_steps=10, sphere_loss=1e-20,
tests/test_vars.py CHANGED
@@ -156,6 +156,7 @@ def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
     for k,v in v1.__dict__.items():
         if not k.startswith('__'):
             # if k == 'post_step_hooks': continue
+            if k == 'storage': continue
             if k == 'update' and clone_update:
                 if v1.update is None or v2.update is None:
                     assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
torchzero/core/module.py CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
 from collections.abc import Callable, Iterable, MutableMapping, Sequence
 from operator import itemgetter
-from typing import Any, final, overload
+from typing import Any, final, overload, Literal
 
 import torch
 
@@ -14,6 +14,7 @@ from ..utils import (
     _make_param_groups,
     get_state_vals,
 )
+from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 from ..utils.python_tools import flatten
 
 
@@ -109,6 +110,9 @@ class Var:
         self.skip_update: bool = False
         """if True, the parameters will not be updated"""
 
+        self.storage: dict = {}
+        """Storage for any other data, such as hessian estimates, etc"""
+
     def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
         """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
         Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
@@ -186,6 +190,7 @@ class Var:
         if self.loss is None: self.loss = var.loss
         if self.loss_approx is None: self.loss_approx = var.loss_approx
         if self.grad is None: self.grad = var.grad
+        self.storage.update(var.storage)
 
     def zero_grad(self, set_to_none=True):
         if set_to_none:
@@ -358,6 +363,26 @@ class Module(ABC):
     # # if isinstance(params, Vars): params = params.params
     # return itemgetter(*keys)(self.settings[params[0]])
 
+    def clear_state_keys(self, *keys:str):
+        for s in self.state.values():
+            for k in keys:
+                if k in s: del s[k]
+
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: str, values: Sequence): ...
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: Sequence[str], values: Sequence[Sequence]): ...
+    def store(self, params: Sequence[torch.Tensor], keys: str | Sequence[str], values: Sequence):
+        if isinstance(keys, str):
+            for p,v in zip(params, values):
+                state = self.state[p]
+                state[keys] = v
+            return
+
+        for p, *p_v in zip(params, *values):
+            state = self.state[p]
+            for k,v in zip(keys, p_v): state[k] = v
+
     def state_dict(self):
         """state dict"""
         packed_state = {id(k):v for k,v in self.state.items()}
@@ -403,23 +428,111 @@ class Module(ABC):
         self._extra_unpack(state_dict['extra'])
 
     # ---------------------------- OVERRIDABLE METHODS --------------------------- #
-    @abstractmethod
     def step(self, var: Var) -> Var:
-        """performs a step, returns new var but may update them in-place."""
+        """performs a step, returns new var but may update it in-place."""
+        self.update(var)
+        return self.apply(var)
+
+    def update(self, var:Var) -> Any:
+        """Updates the internal state of this module. This should not modify `var.update`.
+
+        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+        such as ::code::`tz.m.Online`.
+        """
+
+    def apply(self, var: Var) -> Var:
+        """Applies this module to ``var.get_update()``. This should not modify the internal state of this module if possible."""
+        raise NotImplementedError(f"{self} doesn't implement the `apply` method.")
 
     def reset(self):
-        """Resets the internal state of the module (e.g. momentum)."""
+        """Resets the internal state of the module (e.g. momentum). By default clears state and global state."""
         # no complex logic is allowed there because this is overridden by many modules
         # where super().reset() shouldn't be called
         self.state.clear()
         self.global_state.clear()
 
+    def reset_for_online(self):
+        """resets only the intermediate state of this module, e.g. previous parameters and gradient."""
+        for c in self.children.values(): c.reset_for_online()
+
     def _extra_pack(self):
         return {}
 
     def _extra_unpack(self, x):
         pass
 
+
+    # ------------------------------ HELPER METHODS ------------------------------ #
+    @torch.no_grad
+    def Hvp(
+        self,
+        v: Sequence[torch.Tensor],
+        at_x0: bool,
+        var: Var,
+        rgrad: Sequence[torch.Tensor] | None,
+        hvp_method: Literal['autograd', 'forward', 'central'],
+        h: float,
+        normalize: bool,
+        retain_grad: bool,
+    ):
+        """
+        Returns ``(Hvp, rgrad)``. ``rgrad`` is gradient at current parameters, possibly with create_graph=True, or it may be None with ``hvp_method="central"``. Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
+
+        Single sample example:
+
+        .. code:: py
+
+            Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+
+        Multiple samples example:
+
+        .. code:: py
+
+            D = None
+            rgrad = None
+            for i in range(n_samples):
+                v = [torch.randn_like(p) for p in params]
+                Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+        Args:
+            v (Sequence[torch.Tensor]): vector in hessian-vector product
+            at_x0 (bool): whether this is being called at original or perturbed parameters.
+            var (Var): Var
+            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
+            hvp_method (str): hvp method.
+            h (float): finite difference step size
+            normalize (bool): whether to normalize v for finite difference
+            retain_grad (bool): retain grad
+        """
+        # get grad
+        if rgrad is None and hvp_method in ('autograd', 'forward'):
+            if at_x0: rgrad = var.get_grad(create_graph = hvp_method=='autograd')
+            else:
+                if var.closure is None: raise RuntimeError("Closure is required to calculate HVp")
+                with torch.enable_grad():
+                    loss = var.closure()
+                    rgrad = torch.autograd.grad(loss, var.params, create_graph = hvp_method=='autograd')
+
+        if hvp_method == 'autograd':
+            assert rgrad is not None
+            Hvp = hvp(var.params, rgrad, v, retain_graph=retain_grad)
+
+        elif hvp_method == 'forward':
+            assert rgrad is not None
+            loss, Hvp = hvp_fd_forward(var.closure, var.params, v, h=h, g_0=rgrad, normalize=normalize)
+
+        elif hvp_method == 'central':
+            loss, Hvp = hvp_fd_central(var.closure, var.params, v, h=h, normalize=normalize)
+
+        else:
+            raise ValueError(hvp_method)
+
+        return Hvp, rgrad
+
 # endregion
 
 Chainable = Module | Sequence[Module]
@@ -440,6 +553,21 @@ def unroll_modules(*modules: Chainable) -> list[Module]:
 
 # region Modular
 # ---------------------------------- Modular --------------------------------- #
+
+class _EvalCounterClosure:
+    """keeps track of how many times closure has been evaluated"""
+    __slots__ = ("modular", "closure")
+    def __init__(self, modular: "Modular", closure):
+        self.modular = modular
+        self.closure = closure
+
+    def __call__(self, *args, **kwargs):
+        if self.closure is None:
+            raise RuntimeError("One of the modules requires closure to be passed to the step method")
+
+        self.modular.num_evaluations += 1
+        return self.closure(*args, **kwargs)
+
 # have to inherit from Modular to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
 class Modular(torch.optim.Optimizer):
@@ -496,7 +624,10 @@ class Modular(torch.optim.Optimizer):
         # self.add_param_group(param_group)
 
         self.current_step = 0
-        """The global step counter for the optimizer."""
+        """global step counter for the optimizer."""
+
+        self.num_evaluations = 0
+        """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
 
     def add_param_group(self, param_group: dict[str, Any]):
         proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
@@ -558,11 +689,12 @@ class Modular(torch.optim.Optimizer):
 
         # create var
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=closure, model=self.model, current_step=self.current_step)
+        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step)
 
         # if closure is None, assume backward has been called and gather grads
        if closure is None:
             var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            self.num_evaluations += 1
 
         last_module = self.modules[-1]
         last_lr = last_module.defaults.get('lr', None)
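As a usage note on the new ``update``/``apply`` split and the ``store`` helper shown in the module.py diff above: a custom module can now keep its state maintenance separate from its application to the update. The sketch below is illustrative only and written only against the API visible in this diff; the class name ``EMAUpdate`` is hypothetical, ``beta`` is hardcoded because the ``Module.__init__``/defaults machinery is not shown here, and real torchzero modules may need additional setup.

from torchzero.core.module import Module, Var


class EMAUpdate(Module):
    """Replaces the update with its exponential moving average (illustrative sketch)."""

    beta = 0.9  # hardcoded; constructor/defaults handling is not part of this diff

    def update(self, var: Var):
        # maintain per-parameter EMA buffers; var.update itself is left untouched here
        update = var.get_update()
        bufs = []
        for p, u in zip(var.params, update):
            buf = self.state[p].get("ema")
            if buf is None:
                buf = u.detach().clone()
            else:
                buf.mul_(self.beta).add_(u, alpha=1 - self.beta)
            bufs.append(buf)
        # store() (added in 0.3.11) writes one value per parameter into self.state
        self.store(var.params, "ema", bufs)

    def apply(self, var: Var) -> Var:
        # read the buffers and hand back the new update; no module state is mutated here
        var.update = [self.state[p]["ema"].clone() for p in var.params]
        return var


# the module would then be chained like any other, e.g.
# opt = tz.Modular(model.parameters(), EMAUpdate(), tz.m.LR(0.01))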