torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
tests/test_opts.py CHANGED
@@ -56,14 +56,17 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
  if use_closure:
  def closure(backward=True):
  loss = objective()
+ losses.append(loss.detach())
  if backward:
  opt.zero_grad()
  loss.backward()
  return loss
- loss = opt.step(closure)
- assert loss is not None
- assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
- losses.append(loss)
+ ret = opt.step(closure)
+ assert ret is not None # the return should be the loss
+ with torch.no_grad():
+ loss = objective() # in case f(x_0) is not evaluated
+ assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
+ losses.append(loss.detach())

  else:
  loss = objective()
@@ -71,7 +74,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
  loss.backward()
  opt.step()
  assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
- losses.append(loss)
+ losses.append(loss.detach())

  losses.append(objective())
  return torch.stack(losses).nan_to_num(0,10000,10000).min()
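Note: the closure tested by _run_objective above follows the torchzero convention, where the closure accepts a backward flag, recomputes the loss, and only calls opt.zero_grad() and loss.backward() when asked, so modules that re-evaluate the objective (line searches, sampling) can call it repeatedly. A minimal standalone sketch of that pattern, using only calls that appear in this diff (the Booth objective and this particular module stack are illustrative, not the test's configuration):

    import torch
    import torchzero as tz

    x = torch.tensor([-3.0, 2.0], requires_grad=True)

    def objective():
        # Booth function, minimum 0 at (1, 3)
        return (x[0] + 2 * x[1] - 7) ** 2 + (2 * x[0] + x[1] - 5) ** 2

    opt = tz.Modular([x], tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe())

    def closure(backward=True):
        loss = objective()
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(25):
        opt.step(closure)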
@@ -374,6 +377,21 @@ RandomizedFDM_central4 = Run(
  func='booth', steps=50, loss=10, merge_invariant=True,
  sphere_steps=100, sphere_loss=450,
  )
+ RandomizedFDM_forward4 = Run(
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
+ needs_closure=True,
+ func='booth', steps=50, loss=10, merge_invariant=True,
+ sphere_steps=100, sphere_loss=450,
+ )
+ RandomizedFDM_forward5 = Run(
+ func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
+ needs_closure=True,
+ func='booth', steps=50, loss=10, merge_invariant=True,
+ sphere_steps=100, sphere_loss=450,
+ )
+

  RandomizedFDM_4samples = Run(
  func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, seed=0), tz.m.LR(0.1)),
@@ -455,25 +473,11 @@ Backtracking = Run(
  func='booth', steps=50, loss=0, merge_invariant=True,
  sphere_steps=2, sphere_loss=0,
  )
- Backtracking_try_negative = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.Backtracking(try_negative=True)),
- needs_closure=True,
- func='booth', steps=50, loss=1e-9, merge_invariant=True,
- sphere_steps=2, sphere_loss=1e-10,
- )
  AdaptiveBacktracking = Run(
  func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
  sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
  needs_closure=True,
- func='booth', steps=50, loss=1e-12, merge_invariant=True,
- sphere_steps=2, sphere_loss=1e-10,
- )
- AdaptiveBacktracking_try_negative = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
- needs_closure=True,
- func='booth', steps=50, loss=1e-8, merge_invariant=True,
+ func='booth', steps=50, loss=1e-11, merge_invariant=True,
  sphere_steps=2, sphere_loss=1e-10,
  )
  # ----------------------------- line_search/scipy ---------------------------- #
@@ -578,8 +582,8 @@ UpdateGradientSignConsistency = Run(
  sphere_steps=10, sphere_loss=2,
  )
  IntermoduleCautious = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.01)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.1)),
+ func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.01)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_restart=True)), tz.m.LR(0.1)),
  needs_closure=False,
  func='booth', steps=50, loss=1e-4, merge_invariant=True,
  sphere_steps=10, sphere_loss=0.1,
@@ -592,8 +596,8 @@ ScaleByGradCosineSimilarity = Run(
  sphere_steps=10, sphere_loss=0.1,
  )
  ScaleModulesByCosineSimilarity = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.05)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.1)),
+ func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.05)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_restart=True)),tz.m.LR(0.1)),
  needs_closure=False,
  func='booth', steps=50, loss=0.005, merge_invariant=True,
  sphere_steps=10, sphere_loss=0.1,
@@ -601,47 +605,69 @@ ScaleModulesByCosineSimilarity = Run(

  # ------------------------- momentum/matrix_momentum ------------------------- #
  MatrixMomentum_forward = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='forward'),),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward')),
  needs_closure=True,
  func='booth', steps=50, loss=0.05, merge_invariant=True,
- sphere_steps=10, sphere_loss=0,
+ sphere_steps=10, sphere_loss=0.01,
  )
  MatrixMomentum_forward = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='central')),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central')),
  needs_closure=True,
  func='booth', steps=50, loss=0.05, merge_invariant=True,
- sphere_steps=10, sphere_loss=0,
+ sphere_steps=10, sphere_loss=0.01,
  )
  MatrixMomentum_forward = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='autograd')),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd')),
  needs_closure=True,
  func='booth', steps=50, loss=0.05, merge_invariant=True,
- sphere_steps=10, sphere_loss=0,
+ sphere_steps=10, sphere_loss=0.01,
  )

  AdaptiveMatrixMomentum_forward = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True)),
  needs_closure=True,
- func='booth', steps=50, loss=0.002, merge_invariant=True,
- sphere_steps=10, sphere_loss=0,
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
+ sphere_steps=10, sphere_loss=0.05,
  )
  AdaptiveMatrixMomentum_central = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True)),
  needs_closure=True,
- func='booth', steps=50, loss=0.002, merge_invariant=True,
- sphere_steps=10, sphere_loss=0,
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
+ sphere_steps=10, sphere_loss=0.05,
  )
  AdaptiveMatrixMomentum_autograd = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
- sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True)),
  needs_closure=True,
- func='booth', steps=50, loss=0.002, merge_invariant=True,
- sphere_steps=10, sphere_loss=0,
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
+ sphere_steps=10, sphere_loss=0.05,
+ )
+
+ StochasticAdaptiveMatrixMomentum_forward = Run(
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True, adapt_freq=1)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True, adapt_freq=1)),
+ needs_closure=True,
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
+ sphere_steps=10, sphere_loss=0.05,
+ )
+ StochasticAdaptiveMatrixMomentum_central = Run(
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True, adapt_freq=1)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True, adapt_freq=1)),
+ needs_closure=True,
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
+ sphere_steps=10, sphere_loss=0.05,
+ )
+ StochasticAdaptiveMatrixMomentum_autograd = Run(
+ func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='autograd', adaptive=True, adapt_freq=1)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='autograd', adaptive=True, adapt_freq=1)),
+ needs_closure=True,
+ func='booth', steps=50, loss=0.05, merge_invariant=True,
+ sphere_steps=10, sphere_loss=0.05,
  )

  # EMA, momentum are covered by test_identical
@@ -668,8 +694,8 @@ UpdateSign = Run(
  sphere_steps=10, sphere_loss=0,
  )
  GradAccumulation = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.05), 10), ),
- sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(tz.m.LR(0.5), 10), ),
+ func_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GradientAccumulation(n=10), tz.m.LR(0.5)),
  needs_closure=False,
  func='booth', steps=50, loss=25, merge_invariant=True,
  sphere_steps=20, sphere_loss=1e-11,
@@ -725,24 +751,24 @@ Shampoo = Run(

  # ------------------------- quasi_newton/quasi_newton ------------------------ #
  BFGS = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
- sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
+ func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe()),
  needs_closure=True,
  func='rosen', steps=50, loss=1e-10, merge_invariant=True,
  sphere_steps=10, sphere_loss=1e-10,
  )
  SR1 = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
- sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
+ func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_restart=True, scale_first=True), tz.m.StrongWolfe(fallback=False)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
  needs_closure=True,
  func='rosen', steps=50, loss=1e-12, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )
  SSVM = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
- sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
+ func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_restart=True), tz.m.StrongWolfe(fallback=True)),
  needs_closure=True,
- func='rosen', steps=50, loss=0.5, merge_invariant=True,
+ func='rosen', steps=50, loss=0.2, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )

@@ -757,8 +783,8 @@ LBFGS = Run(

  # ----------------------------- quasi_newton/lsr1 ---------------------------- #
  LSR1 = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
- sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
+ func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe(c2=0.1, fallback=True)),
  needs_closure=True,
  func='rosen', steps=50, loss=0, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
@@ -775,8 +801,8 @@ LSR1 = Run(


  # ---------------------------- second_order/newton --------------------------- #
  Newton = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
- sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe()),
+ func_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.Newton(), tz.m.StrongWolfe(fallback=True)),
  needs_closure=True,
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
  sphere_steps=2, sphere_loss=1e-9,
@@ -784,8 +810,8 @@ Newton = Run(

  # --------------------------- second_order/newton_cg -------------------------- #
  NewtonCG = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
- sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
+ func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
  needs_closure=True,
  func='rosen', steps=20, loss=1e-7, merge_invariant=True,
  sphere_steps=2, sphere_loss=3e-4,
@@ -793,11 +819,11 @@ NewtonCG = Run(

  # ---------------------------- smoothing/gaussian ---------------------------- #
  GaussianHomotopy = Run(
- func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
- sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
+ func_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
+ sphere_opt=lambda p: tz.Modular(p, tz.m.GradientSampling([tz.m.BFGS(), tz.m.Backtracking()], 1e-1, 10, termination=tz.m.TerminateByUpdateNorm(1e-1), seed=0)),
  needs_closure=True,
- func='booth', steps=20, loss=0.1, merge_invariant=True,
- sphere_steps=10, sphere_loss=200,
+ func='booth', steps=20, loss=0.01, merge_invariant=True,
+ sphere_steps=10, sphere_loss=1,
  )
  # ---------------------------- smoothing/laplacian --------------------------- #
@@ -879,14 +905,14 @@ Adan = Run(
  )

  # ------------------------------------ CGs ----------------------------------- #
- for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
+ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.DYHS, tz.m.ProjectedGradientMethod):
  for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
  # but also test 10 to make sure it doesn't explode after converging
  Run(
  func_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
  sphere_opt=lambda p: tz.Modular(p, CG(), tz.m.StrongWolfe(c2=0.1)),
  needs_closure=True,
- func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=False, # strong wolfe adds float imprecision
+ func='lstsq', steps=func_steps, loss=1e-10, merge_invariant=True,
  sphere_steps=sphere_steps_, sphere_loss=0,
  )

@@ -917,10 +943,10 @@ for QN in (
  tz.m.SSVM,
  ):
  Run(
- func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
- sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
+ func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
+ sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_restart=True), tz.m.StrongWolfe()),
  needs_closure=True,
- func='lstsq', steps=50, loss=1e-10, merge_invariant=False,
+ func='lstsq', steps=50, loss=1e-10, merge_invariant=True,
  sphere_steps=10, sphere_loss=1e-20,
  )
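Most of the edits above track API changes between 0.3.11 and 0.3.13 rather than new test logic: the quasi-Newton ptol_reset flag is renamed to ptol_restart, AdaptiveMatrixMomentum is folded into MatrixMomentum behind an adaptive flag (with adapt_freq used for the stochastic-adaptive runs), GradientAccumulation takes n= and is chained with a separate LR module, and GaussianHomotopy is replaced by GradientSampling with an explicit termination module. A minimal sketch of the 0.3.13-style constructions, restricted to calls that appear verbatim in this diff (parameter values are illustrative):

    import torch
    import torchzero as tz

    params = [torch.randn(10, requires_grad=True)]

    # quasi-Newton: restart flag renamed from ptol_reset to ptol_restart
    bfgs = tz.Modular(params, tz.m.BFGS(ptol_restart=True), tz.m.StrongWolfe())

    # matrix momentum: a leading positional value replaces the separate tz.m.LR module,
    # and the adaptive variant is a flag; adapt_freq=1 mirrors the stochastic-adaptive runs
    mm = tz.Modular(params, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True, adapt_freq=1))

    # gradient accumulation: takes n= and composes with a separate LR module
    ga = tz.Modular(params, tz.m.GradientAccumulation(n=10), tz.m.LR(0.05))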
tests/test_tensorlist.py CHANGED
@@ -977,22 +977,23 @@ def test_rademacher_like(big_tl: TensorList):

  @pytest.mark.parametrize("dist", ['normal', 'uniform', 'sphere', 'rademacher'])
  def test_sample_like(simple_tl: TensorList, dist):
- eps_scalar = 2.0
- result_tl_scalar = simple_tl.sample_like(eps_scalar, distribution=dist)
+ eps_scalar = 1
+ result_tl_scalar = simple_tl.sample_like(distribution=dist)
  assert isinstance(result_tl_scalar, TensorList)
  assert result_tl_scalar.shape == simple_tl.shape

- eps_list = [0.5, 1.0, 1.5]
- result_tl_list = simple_tl.sample_like(eps_list, distribution=dist)
+ eps_list = [1.0,]
+ result_tl_list = simple_tl.sample_like(distribution=dist)
  assert isinstance(result_tl_list, TensorList)
  assert result_tl_list.shape == simple_tl.shape

  # Basic checks based on distribution
  if dist == 'uniform':
- assert all(torch.all((t >= -eps_scalar/2) & (t <= eps_scalar/2)) for t in result_tl_scalar)
- assert all(torch.all((t >= -e/2) & (t <= e/2)) for t, e in zip(result_tl_list, eps_list))
+ assert all(torch.all((t >= -eps_scalar) & (t <= eps_scalar)) for t in result_tl_scalar)
+ assert all(torch.all((t >= -e) & (t <= e)) for t, e in zip(result_tl_list, eps_list))
  elif dist == 'sphere':
- assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
+ # assert torch.allclose(result_tl_scalar.global_vector_norm(), torch.tensor(eps_scalar))
+ pass
  # Cannot check list version easily
  elif dist == 'rademacher':
  assert all(torch.all((t == -eps_scalar) | (t == eps_scalar)) for t in result_tl_scalar)
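The updated test implies that TensorList.sample_like now takes only the distribution name, with no per-tensor scale argument. A small sketch of that usage; the import path and list construction are assumptions, while sample_like(distribution=...), .shape, and the ±1 Rademacher check come from the test itself:

    import torch
    from torchzero.utils import TensorList  # assumed import path

    tl = TensorList([torch.zeros(3), torch.zeros(2, 2)])  # assumed constructor

    sample = tl.sample_like(distribution='rademacher')
    assert sample.shape == tl.shape
    assert all(torch.all((t == -1) | (t == 1)) for t in sample)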
torchzero/__init__.py CHANGED
@@ -1,4 +1,4 @@
  from . import core, optim, utils
  from .core import Modular
- from .utils import compile
+ from .utils import set_compilation
  from . import modules as m
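The only functional change here is the renamed re-export: the compile helper from torchzero.utils is now exposed as set_compilation. Its call signature is not shown in this diff, so only the import rename is sketched:

    # 0.3.11
    # from torchzero.utils import compile
    # 0.3.13
    from torchzero.utils import set_compilation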
torchzero/core/__init__.py CHANGED
@@ -1,2 +1,2 @@
- from .module import Var, Module, Modular, Chain, maybe_chain, Chainable
- from .transform import Transform, TensorwiseTransform, Target, apply_transform
+ from .module import Chain, Chainable, Modular, Module, Var, maybe_chain
+ from .transform import Target, TensorwiseTransform, Transform, apply_transform