torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
docs/source/conf.py CHANGED
@@ -6,10 +6,10 @@
6
6
  # -- Project information -----------------------------------------------------
7
7
  # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8
8
  import sys, os
9
- #sys.path.insert(0, os.path.abspath('.../src'))
9
+ #sys.path.insert(0, os.path.abspath('.../src'))
10
10
 
11
11
  project = 'torchzero'
12
- copyright = '2024, Ivan Nikishev'
12
+ copyright = '2025, Ivan Nikishev'
13
13
  author = 'Ivan Nikishev'
14
14
 
15
15
  # -- General configuration ---------------------------------------------------
@@ -24,10 +24,12 @@ extensions = [
24
24
  'sphinx.ext.githubpages',
25
25
  'sphinx.ext.napoleon',
26
26
  'autoapi.extension',
27
+ "myst_nb",
28
+
27
29
  # 'sphinx_rtd_theme',
28
30
  ]
29
31
  autosummary_generate = True
30
- autoapi_dirs = ['../../src']
32
+ autoapi_dirs = ['../../torchzero']
31
33
  autoapi_type = "python"
32
34
  # autoapi_ignore = ["*/tensorlist.py"]
33
35
 
@@ -48,7 +50,7 @@ exclude_patterns = []
48
50
  # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
49
51
 
50
52
  #html_theme = 'alabaster'
51
- html_theme = 'furo'
53
+ html_theme = 'sphinx_rtd_theme'
52
54
  html_static_path = ['_static']
53
55
 
54
56
 
@@ -0,0 +1,46 @@
1
+ class MyModule:
2
+ """[One-line summary of the class].
3
+
4
+ [A more detailed description of the class, explaining its purpose, how it
5
+ works, and its typical use cases. You can use multiple paragraphs.]
6
+
7
+ .. note::
8
+ [Optional: Add important notes, warnings, or usage guidelines here.
9
+ For example, you could mention if a closure is required, discuss
10
+ stability, or highlight performance characteristics. Use the `.. note::`
11
+ directive to make it stand out in the documentation.]
12
+
13
+ Args:
14
+ param1 (type, optional):
15
+ [Description of the first parameter. Use :code:`backticks` for
16
+ inline code like variable names or specific values like ``"autograd"``.
17
+ Explain what the parameter does.] Defaults to [value].
18
+ param2 (type):
19
+ [Description of a mandatory parameter (no "optional" or "Defaults to").]
20
+ **kwargs:
21
+ [If you accept keyword arguments, describe what they are used for.]
22
+
23
+ Examples:
24
+ [A title or short sentence describing the first example]:
25
+
26
+ .. code-block:: python
27
+
28
+ opt = tz.Modular(
29
+ model.parameters(),
30
+ ...
31
+ )
32
+
33
+ [A title or short sentence for a second, different example]:
34
+
35
+ .. code-block:: python
36
+
37
+ opt = tz.Modular(
38
+ model.parameters(),
39
+ ...
40
+ )
41
+
42
+ References:
43
+ - [Optional: A citation for a relevant paper, book, or algorithm.]
44
+ - [Optional: A link to a blog post or website with more information.]
45
+
46
+ """
tests/test_identical.py CHANGED
@@ -96,8 +96,7 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
96
96
 
97
97
  @pytest.mark.parametrize('amsgrad', [True, False])
98
98
  def test_adam(amsgrad):
99
- # torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
100
- # pytorch applies debiasing separately so it is applied before epsilo
99
+ torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
101
100
  tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
102
101
  tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
103
102
  tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
@@ -133,7 +132,7 @@ def test_adam(amsgrad):
133
132
  tz.m.Debias2(beta=0.999),
134
133
  tz.m.Add(1e-8)]
135
134
  ))
136
- tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
135
+ tz_fns = (torch_fn, tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
137
136
 
138
137
  _assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
139
138
  for fn in tz_fns:
tests/test_opts.py CHANGED
@@ -1,4 +1,9 @@
1
- """snity tests to make sure everything works and converges on basic functions"""
1
+ """
2
+ Sanity tests to make sure everything works.
3
+
4
+ This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
5
+ don't error or become unhinged with different parameter shapes.
6
+ """
2
7
  from collections.abc import Callable
3
8
  from functools import partial
4
9
 
@@ -68,6 +73,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
68
73
  assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
69
74
  losses.append(loss)
70
75
 
76
+ losses.append(objective())
71
77
  return torch.stack(losses).nan_to_num(0,10000,10000).min()
72
78
 
73
79
  def _run_func(opt_fn: Callable, func:str, merge: bool, use_closure: bool, steps: int):
@@ -286,42 +292,42 @@ FDM_central2 = Run(
286
292
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
287
293
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
288
294
  needs_closure=True,
289
- func='booth', steps=50, loss=1e-7, merge_invariant=True,
295
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
290
296
  sphere_steps=2, sphere_loss=340,
291
297
  )
292
298
  FDM_forward2 = Run(
293
299
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
294
300
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
295
301
  needs_closure=True,
296
- func='booth', steps=50, loss=1e-7, merge_invariant=True,
302
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
297
303
  sphere_steps=2, sphere_loss=340,
298
304
  )
299
305
  FDM_backward2 = Run(
300
306
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
301
307
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
302
308
  needs_closure=True,
303
- func='booth', steps=50, loss=2e-7, merge_invariant=True,
309
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
304
310
  sphere_steps=2, sphere_loss=340,
305
311
  )
306
312
  FDM_forward3 = Run(
307
313
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
308
314
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
309
315
  needs_closure=True,
310
- func='booth', steps=50, loss=3e-7, merge_invariant=True,
316
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
311
317
  sphere_steps=2, sphere_loss=340,
312
318
  )
313
319
  FDM_backward3 = Run(
314
320
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
315
321
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
316
322
  needs_closure=True,
317
- func='booth', steps=50, loss=3e-7, merge_invariant=True,
323
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
318
324
  sphere_steps=2, sphere_loss=340,
319
325
  )
320
326
  FDM_central4 = Run(
321
327
  func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
322
328
  sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
323
329
  needs_closure=True,
324
- func='booth', steps=50, loss=2e-8, merge_invariant=True,
330
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
325
331
  sphere_steps=2, sphere_loss=340,
326
332
  )
327
333
 
@@ -460,8 +466,8 @@ AdaptiveBacktracking = Run(
460
466
  func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
461
467
  sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
462
468
  needs_closure=True,
463
- func='booth', steps=50, loss=0, merge_invariant=True,
464
- sphere_steps=2, sphere_loss=0,
469
+ func='booth', steps=50, loss=1e-12, merge_invariant=True,
470
+ sphere_steps=2, sphere_loss=1e-10,
465
471
  )
466
472
  AdaptiveBacktracking_try_negative = Run(
467
473
  func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
@@ -488,15 +494,6 @@ StrongWolfe = Run(
488
494
  sphere_steps=2, sphere_loss=0,
489
495
  )
490
496
 
491
- # ------------------------- line_search/trust_region ------------------------- #
492
- TrustRegion = Run(
493
- func_opt=lambda p: tz.Modular(p, tz.m.TrustRegion()),
494
- sphere_opt=lambda p: tz.Modular(p, tz.m.TrustRegion(init=0.1)),
495
- needs_closure=True,
496
- func='booth', steps=50, loss=0.1, merge_invariant=True,
497
- sphere_steps=10, sphere_loss=1e-5,
498
- )
499
-
500
497
  # ----------------------------------- lr/lr ---------------------------------- #
501
498
  LR = Run(
502
499
  func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
@@ -524,7 +521,7 @@ PolyakStepSize = Run(
524
521
  func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
525
522
  sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
526
523
  needs_closure=True,
527
- func='booth', steps=50, loss=1e-11, merge_invariant=True,
524
+ func='booth', steps=50, loss=1e-7, merge_invariant=True,
528
525
  sphere_steps=10, sphere_loss=0.002,
529
526
  )
530
527
  RandomStepSize = Run(
@@ -581,8 +578,8 @@ UpdateGradientSignConsistency = Run(
581
578
  sphere_steps=10, sphere_loss=2,
582
579
  )
583
580
  IntermoduleCautious = Run(
584
- func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.01)),
585
- sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.1)),
581
+ func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.01)),
582
+ sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.1)),
586
583
  needs_closure=False,
587
584
  func='booth', steps=50, loss=1e-4, merge_invariant=True,
588
585
  sphere_steps=10, sphere_loss=0.1,
@@ -595,8 +592,8 @@ ScaleByGradCosineSimilarity = Run(
595
592
  sphere_steps=10, sphere_loss=0.1,
596
593
  )
597
594
  ScaleModulesByCosineSimilarity = Run(
598
- func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.05)),
599
- sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.1)),
595
+ func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.05)),
596
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.1)),
600
597
  needs_closure=False,
601
598
  func='booth', steps=50, loss=0.005, merge_invariant=True,
602
599
  sphere_steps=10, sphere_loss=0.1,
@@ -604,44 +601,44 @@ ScaleModulesByCosineSimilarity = Run(
604
601
 
605
602
  # ------------------------- momentum/matrix_momentum ------------------------- #
606
603
  MatrixMomentum_forward = Run(
607
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.01)),
608
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.5)),
604
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
605
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
609
606
  needs_closure=True,
610
607
  func='booth', steps=50, loss=0.05, merge_invariant=True,
611
608
  sphere_steps=10, sphere_loss=0,
612
609
  )
613
610
  MatrixMomentum_forward = Run(
614
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='central'), tz.m.LR(0.01)),
615
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='central'), tz.m.LR(0.5)),
611
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
612
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
616
613
  needs_closure=True,
617
614
  func='booth', steps=50, loss=0.05, merge_invariant=True,
618
615
  sphere_steps=10, sphere_loss=0,
619
616
  )
620
617
  MatrixMomentum_forward = Run(
621
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.01)),
622
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.5)),
618
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
619
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
623
620
  needs_closure=True,
624
621
  func='booth', steps=50, loss=0.05, merge_invariant=True,
625
622
  sphere_steps=10, sphere_loss=0,
626
623
  )
627
624
 
628
625
  AdaptiveMatrixMomentum_forward = Run(
629
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='forward'), tz.m.LR(0.05)),
630
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='forward'), tz.m.LR(0.5)),
626
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
627
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
631
628
  needs_closure=True,
632
629
  func='booth', steps=50, loss=0.002, merge_invariant=True,
633
630
  sphere_steps=10, sphere_loss=0,
634
631
  )
635
632
  AdaptiveMatrixMomentum_central = Run(
636
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='central'), tz.m.LR(0.05)),
637
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='central'), tz.m.LR(0.5)),
633
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
634
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
638
635
  needs_closure=True,
639
636
  func='booth', steps=50, loss=0.002, merge_invariant=True,
640
637
  sphere_steps=10, sphere_loss=0,
641
638
  )
642
639
  AdaptiveMatrixMomentum_autograd = Run(
643
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.05)),
644
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.5)),
640
+ func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
641
+ sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
645
642
  needs_closure=True,
646
643
  func='booth', steps=50, loss=0.002, merge_invariant=True,
647
644
  sphere_steps=10, sphere_loss=0,
@@ -678,8 +675,8 @@ GradAccumulation = Run(
678
675
  sphere_steps=20, sphere_loss=1e-11,
679
676
  )
680
677
  NegateOnLossIncrease = Run(
681
- func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(),),
682
- sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(),),
678
+ func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
679
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
683
680
  needs_closure=True,
684
681
  func='booth', steps=50, loss=0.1, merge_invariant=True,
685
682
  sphere_steps=20, sphere_loss=0.001,
@@ -687,7 +684,7 @@ NegateOnLossIncrease = Run(
687
684
  # -------------------------------- misc/switch ------------------------------- #
688
685
  Alternate = Run(
689
686
  func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
690
- sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
687
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
691
688
  needs_closure=False,
692
689
  func='booth', steps=50, loss=1, merge_invariant=True,
693
690
  sphere_steps=20, sphere_loss=20,
@@ -719,33 +716,33 @@ Lion = Run(
719
716
  )
720
717
  # ---------------------------- optimizers/shampoo ---------------------------- #
721
718
  Shampoo = Run(
722
- func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
723
- sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.2)),
719
+ func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
720
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
724
721
  needs_closure=False,
725
- func='booth', steps=50, loss=200, merge_invariant=False,
726
- sphere_steps=20, sphere_loss=1e-3, # merge and unmerge lrs are very different so need to test convergence separately somewhere
722
+ func='booth', steps=50, loss=0.02, merge_invariant=False,
723
+ sphere_steps=20, sphere_loss=1, # merge and unmerge lrs are very different so need to test convergence separately somewhere
727
724
  )
728
725
 
729
726
  # ------------------------- quasi_newton/quasi_newton ------------------------ #
730
727
  BFGS = Run(
731
- func_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
732
- sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
728
+ func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
729
+ sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
733
730
  needs_closure=True,
734
- func='rosen', steps=50, loss=0, merge_invariant=True,
735
- sphere_steps=10, sphere_loss=0,
731
+ func='rosen', steps=50, loss=1e-10, merge_invariant=True,
732
+ sphere_steps=10, sphere_loss=1e-10,
736
733
  )
737
734
  SR1 = Run(
738
- func_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
739
- sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
735
+ func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
736
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
740
737
  needs_closure=True,
741
738
  func='rosen', steps=50, loss=1e-12, merge_invariant=True,
742
739
  sphere_steps=10, sphere_loss=0,
743
740
  )
744
741
  SSVM = Run(
745
- func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
746
- sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
742
+ func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
743
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
747
744
  needs_closure=True,
748
- func='rosen', steps=50, loss=1e-10, merge_invariant=True,
745
+ func='rosen', steps=50, loss=0.5, merge_invariant=True,
749
746
  sphere_steps=10, sphere_loss=0,
750
747
  )
751
748
 
@@ -760,21 +757,21 @@ LBFGS = Run(
760
757
 
761
758
  # ----------------------------- quasi_newton/lsr1 ---------------------------- #
762
759
  LSR1 = Run(
763
- func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
764
- sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
760
+ func_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
761
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
765
762
  needs_closure=True,
766
763
  func='rosen', steps=50, loss=0, merge_invariant=True,
767
764
  sphere_steps=10, sphere_loss=0,
768
765
  )
769
766
 
770
- # ---------------------------- quasi_newton/olbfgs --------------------------- #
771
- OnlineLBFGS = Run(
772
- func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
773
- sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
774
- needs_closure=True,
775
- func='rosen', steps=50, loss=0, merge_invariant=True,
776
- sphere_steps=10, sphere_loss=0,
777
- )
767
+ # # ---------------------------- quasi_newton/olbfgs --------------------------- #
768
+ # OnlineLBFGS = Run(
769
+ # func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
770
+ # sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
771
+ # needs_closure=True,
772
+ # func='rosen', steps=50, loss=0, merge_invariant=True,
773
+ # sphere_steps=10, sphere_loss=0,
774
+ # )
778
775
 
779
776
  # ---------------------------- second_order/newton --------------------------- #
780
777
  Newton = Run(
@@ -791,13 +788,13 @@ NewtonCG = Run(
791
788
  sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
792
789
  needs_closure=True,
793
790
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
794
- sphere_steps=2, sphere_loss=1e-6,
791
+ sphere_steps=2, sphere_loss=3e-4,
795
792
  )
796
793
 
797
794
  # ---------------------------- smoothing/gaussian ---------------------------- #
798
795
  GaussianHomotopy = Run(
799
- func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
800
- sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
796
+ func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
797
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
801
798
  needs_closure=True,
802
799
  func='booth', steps=20, loss=0.1, merge_invariant=True,
803
800
  sphere_steps=10, sphere_loss=200,
@@ -854,8 +851,35 @@ SophiaH = Run(
854
851
  sphere_steps=10, sphere_loss=40,
855
852
  )
856
853
 
854
+ # -------------------------- higher_order ------------------------- #
855
+ HigherOrderNewton = Run(
856
+ func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
857
+ sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
858
+ needs_closure=True,
859
+ func='rosen', steps=1, loss=2e-10, merge_invariant=True,
860
+ sphere_steps=1, sphere_loss=1e-10,
861
+ )
862
+
863
+ # ---------------------------- optimizers/ladagrad --------------------------- #
864
+ LMAdagrad = Run(
865
+ func_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(4)),
866
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(5)),
867
+ needs_closure=False,
868
+ func='booth', steps=50, loss=1e-6, merge_invariant=True,
869
+ sphere_steps=20, sphere_loss=1e-9,
870
+ )
871
+
872
+ # ------------------------------ optimizers/adan ----------------------------- #
873
+ Adan = Run(
874
+ func_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(1)),
875
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(0.1)),
876
+ needs_closure=False,
877
+ func='booth', steps=50, loss=60, merge_invariant=True,
878
+ sphere_steps=20, sphere_loss=60,
879
+ )
880
+
857
881
  # ------------------------------------ CGs ----------------------------------- #
858
- for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY):
882
+ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
859
883
  for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
860
884
  # but also test 10 to make sure it doesn't explode after converging
861
885
  Run(
@@ -868,10 +892,33 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
868
892
 
869
893
  # ------------------------------- QN stability ------------------------------- #
870
894
  # stability test
871
- for QN in (tz.m.BFGS, tz.m.SR1, tz.m.DFP, tz.m.BroydenGood, tz.m.BroydenBad, tz.m.Greenstadt1, tz.m.Greenstadt2, tz.m.ColumnUpdatingMethod, tz.m.ThomasOptimalMethod, tz.m.PSB, tz.m.Pearson2, tz.m.SSVM):
895
+ for QN in (
896
+ tz.m.BFGS,
897
+ partial(tz.m.BFGS, inverse=False),
898
+ tz.m.SR1,
899
+ partial(tz.m.SR1, inverse=False),
900
+ tz.m.DFP,
901
+ partial(tz.m.DFP, inverse=False),
902
+ tz.m.BroydenGood,
903
+ partial(tz.m.BroydenGood, inverse=False),
904
+ tz.m.BroydenBad,
905
+ partial(tz.m.BroydenBad, inverse=False),
906
+ tz.m.Greenstadt1,
907
+ tz.m.Greenstadt2,
908
+ tz.m.ICUM,
909
+ tz.m.ThomasOptimalMethod,
910
+ tz.m.FletcherVMM,
911
+ tz.m.Horisho,
912
+ partial(tz.m.Horisho, inner=tz.m.GradientCorrection()),
913
+ tz.m.Pearson,
914
+ tz.m.ProjectedNewtonRaphson,
915
+ tz.m.PSB,
916
+ tz.m.McCormick,
917
+ tz.m.SSVM,
918
+ ):
872
919
  Run(
873
- func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
874
- sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
920
+ func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
921
+ sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
875
922
  needs_closure=True,
876
923
  func='lstsq', steps=50, loss=1e-10, merge_invariant=False,
877
924
  sphere_steps=10, sphere_loss=1e-20,
tests/test_tensorlist.py CHANGED
@@ -1261,8 +1261,8 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
1261
1261
  elif reduction_method == 'quantile': expected = vec.quantile(q)
1262
1262
  else:
1263
1263
  pytest.fail("Unknown global reduction")
1264
- assert False, 'sus'
1265
- assert torch.allclose(result, expected)
1264
+ assert False, reduction_method
1265
+ assert torch.allclose(result, expected, atol=1e-4)
1266
1266
  else:
1267
1267
  expected_list = []
1268
1268
  for t in simple_tl: