torchzero 0.3.5__tar.gz → 0.3.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. {torchzero-0.3.5 → torchzero-0.3.8}/PKG-INFO +2 -2
  2. {torchzero-0.3.5 → torchzero-0.3.8}/README.md +1 -1
  3. {torchzero-0.3.5 → torchzero-0.3.8}/pyproject.toml +2 -2
  4. {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_opts.py +1 -1
  5. {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_tensorlist.py +17 -17
  6. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/preconditioner.py +11 -10
  7. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/__init__.py +3 -2
  8. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/absoap.py +8 -2
  9. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/adadam.py +1 -1
  10. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/adamY.py +1 -1
  11. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/adasoap.py +1 -1
  12. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/algebraic_newton.py +1 -1
  13. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/curveball.py +1 -1
  14. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/gradmin.py +1 -1
  15. torchzero-0.3.8/torchzero/modules/experimental/newton_solver.py +88 -0
  16. torchzero-0.3.5/torchzero/modules/experimental/dsoap.py → torchzero-0.3.8/torchzero/modules/experimental/soapy.py +4 -4
  17. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/spectral.py +5 -3
  18. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/subspace_preconditioners.py +16 -9
  19. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/soap.py +1 -2
  20. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/projection.py +27 -1
  21. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +4 -3
  22. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/lbfgs.py +4 -3
  23. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/lsr1.py +6 -3
  24. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/quasi_newton.py +16 -17
  25. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/second_order/__init__.py +1 -1
  26. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/second_order/newton_cg.py +1 -1
  27. torchzero-0.3.8/torchzero/utils/linalg/benchmark.py +20 -0
  28. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/solve.py +15 -14
  29. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/PKG-INFO +2 -2
  30. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/SOURCES.txt +3 -1
  31. {torchzero-0.3.5 → torchzero-0.3.8}/LICENSE +0 -0
  32. {torchzero-0.3.5 → torchzero-0.3.8}/docs/source/conf.py +0 -0
  33. {torchzero-0.3.5 → torchzero-0.3.8}/setup.cfg +0 -0
  34. {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_identical.py +0 -0
  35. {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_module.py +0 -0
  36. {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_utils_optimizer.py +0 -0
  37. {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_vars.py +0 -0
  38. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/__init__.py +0 -0
  39. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/__init__.py +0 -0
  40. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/module.py +0 -0
  41. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/transform.py +0 -0
  42. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/__init__.py +0 -0
  43. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/clipping/__init__.py +0 -0
  44. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/clipping/clipping.py +0 -0
  45. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/clipping/ema_clipping.py +0 -0
  46. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/clipping/growth_clipping.py +0 -0
  47. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
  48. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/tropical_newton.py +0 -0
  49. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/functional.py +0 -0
  50. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/__init__.py +0 -0
  51. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/fdm.py +0 -0
  52. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
  53. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
  54. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/rfdm.py +0 -0
  55. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/__init__.py +0 -0
  56. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/backtracking.py +0 -0
  57. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/line_search.py +0 -0
  58. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/scipy.py +0 -0
  59. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/strong_wolfe.py +0 -0
  60. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/trust_region.py +0 -0
  61. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/lr/__init__.py +0 -0
  62. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/lr/lr.py +0 -0
  63. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/lr/step_size.py +0 -0
  64. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/__init__.py +0 -0
  65. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/averaging.py +0 -0
  66. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/cautious.py +0 -0
  67. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/ema.py +0 -0
  68. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/experimental.py +0 -0
  69. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/matrix_momentum.py +0 -0
  70. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/momentum.py +0 -0
  71. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/__init__.py +0 -0
  72. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/accumulate.py +0 -0
  73. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/binary.py +0 -0
  74. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/debug.py +0 -0
  75. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/misc.py +0 -0
  76. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/multi.py +0 -0
  77. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/reduce.py +0 -0
  78. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/split.py +0 -0
  79. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/switch.py +0 -0
  80. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/unary.py +0 -0
  81. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/utility.py +0 -0
  82. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/__init__.py +0 -0
  83. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/adagrad.py +0 -0
  84. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/adam.py +0 -0
  85. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/lion.py +0 -0
  86. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/muon.py +0 -0
  87. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/orthograd.py +0 -0
  88. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/rmsprop.py +0 -0
  89. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/rprop.py +0 -0
  90. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/shampoo.py +0 -0
  91. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/sophia_h.py +0 -0
  92. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/__init__.py +0 -0
  93. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/dct.py +0 -0
  94. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/fft.py +0 -0
  95. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/galore.py +0 -0
  96. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/structural.py +0 -0
  97. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/__init__.py +0 -0
  98. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/cg.py +0 -0
  99. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
  100. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/olbfgs.py +0 -0
  101. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/second_order/newton.py +0 -0
  102. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/second_order/nystrom.py +0 -0
  103. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/smoothing/__init__.py +0 -0
  104. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/smoothing/gaussian.py +0 -0
  105. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/smoothing/laplacian.py +0 -0
  106. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/weight_decay/__init__.py +0 -0
  107. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/weight_decay/weight_decay.py +0 -0
  108. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/wrappers/__init__.py +0 -0
  109. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
  110. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/__init__.py +0 -0
  111. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/utility/__init__.py +0 -0
  112. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/utility/split.py +0 -0
  113. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/wrappers/__init__.py +0 -0
  114. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/wrappers/nevergrad.py +0 -0
  115. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/wrappers/nlopt.py +0 -0
  116. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/wrappers/scipy.py +0 -0
  117. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/__init__.py +0 -0
  118. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/compile.py +0 -0
  119. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/derivatives.py +0 -0
  120. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/__init__.py +0 -0
  121. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/matrix_funcs.py +0 -0
  122. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/orthogonalize.py +0 -0
  123. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/qr.py +0 -0
  124. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/svd.py +0 -0
  125. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/numberlist.py +0 -0
  126. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/ops.py +0 -0
  127. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/optimizer.py +0 -0
  128. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/optuna_tools.py +0 -0
  129. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/params.py +0 -0
  130. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/python_tools.py +0 -0
  131. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/tensorlist.py +0 -0
  132. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/torch_tools.py +0 -0
  133. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/dependency_links.txt +0 -0
  134. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/requires.txt +0 -0
  135. {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.5
+ Version: 0.3.8
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  License: MIT License
@@ -156,7 +156,7 @@ for epoch in range(100):
  * `Newton`: Classic Newton's method.
  * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
- * `NystromPCG`: NewtonCG with Nyström preconditioning (my current recommendation).
+ * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).

  * **Quasi-Newton**: Approximate second-order optimization methods.
  * `LBFGS`: Limited-memory BFGS.
@@ -117,7 +117,7 @@ for epoch in range(100):
  * `Newton`: Classic Newton's method.
  * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
- * `NystromPCG`: NewtonCG with Nyström preconditioning (my current recommendation).
+ * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).

  * **Quasi-Newton**: Approximate second-order optimization methods.
  * `LBFGS`: Limited-memory BFGS.
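Note: the two hunks above only reword the parenthetical note on `NystromPCG` in the README's module list. For context, these modules are combined through `tz.Modular`, the same pattern visible in the tests/test_opts.py hunk further down. The sketch below is illustrative only; the closure convention and the exact `tz.m.LBFGS` / `tz.m.StrongWolfe` names are assumptions not confirmed by this diff.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)
    xb, yb = torch.randn(32, 4), torch.randn(32, 1)

    # assumed API: Modular(params, *modules), as in the tests/test_opts.py hunk
    opt = tz.Modular(model.parameters(), tz.m.LBFGS(), tz.m.StrongWolfe())

    def closure(backward=True):  # assumed closure convention for line-search modules
        loss = torch.nn.functional.mse_loss(model(xb), yb)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(10):
        opt.step(closure)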
@@ -2,7 +2,7 @@
  # STEP 1 - COMMIT NEW CHANGES BUT DON'T PUSH THEM YET
  # STEP 2 - BUMP VERSION AND COMMIT IT (DONT PUSH!!!!)
  # STEP 3 - CREATE TAG WITH THAT VERSION
- # STEP 4 - PUSH CHANGES
+ # STEP 4 - PUSH (SYNC) CHANGES
  # STEP 5 - PUSH TAG

  [build-system]
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
  name = "torchzero"
  description = "Modular optimization library for PyTorch."

- version = "0.3.5"
+ version = "0.3.8"
  dependencies = [
  "torch",
  "numpy",
@@ -745,7 +745,7 @@ SSVM = Run(
  func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
  sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
  needs_closure=True,
- func='rosen', steps=50, loss=1e-12, merge_invariant=True,
+ func='rosen', steps=50, loss=0.02, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )

@@ -1301,7 +1301,7 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
  expected_tl = TensorList(expected_list)
  assert isinstance(result, TensorList)
  assert len(result) == len(expected_tl)
- assert_tl_allclose(result, expected_tl, atol=1e-6) # Use allclose due to potential float variations
+ assert_tl_allclose(result, expected_tl, atol=1e-3) # Use allclose due to potential float variations

  # --- Grafting, Rescaling, Normalizing, Clipping ---

@@ -1381,8 +1381,8 @@ def test_rescale(simple_tl: TensorList, dim):
  assert torch.allclose(rescaled_scalar.global_min(), torch.tensor(min_val))
  assert torch.allclose(rescaled_scalar.global_max(), torch.tensor(max_val))
  else:
- assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-4)
- assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-4)
+ assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-3)
+ assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-3)


  # Rescale list
@@ -1402,8 +1402,8 @@ def test_rescale(simple_tl: TensorList, dim):
  assert global_max_rescaled < avg_max + 1.0 # Loose check

  else:
- assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-4)
- assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-4)
+ assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-3)
+ assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-3)

  # Rescale to 01 helper
  rescaled_01 = simple_tl.rescale_to_01(dim=dim, eps=eps)
@@ -1413,8 +1413,8 @@ def test_rescale(simple_tl: TensorList, dim):
  assert torch.allclose(rescaled_01.global_min(), torch.tensor(0.0))
  assert torch.allclose(rescaled_01.global_max(), torch.tensor(1.0))
  else:
- assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-4)
- assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-4)
+ assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-3)
+ assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-3)


  # Test inplace
@@ -1454,11 +1454,11 @@ def test_normalize(big_tl: TensorList, dim):
  normalized_scalar_var = normalized_scalar.var(dim=dim if dim != 'global' else None)

  if dim == 'global':
- assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-4)
- assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-4)
+ assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-3)
+ assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-3)
  else:
- assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-4)
- assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-4)
+ assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-3)
+ assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-3)

  # Normalize list mean/var
  normalized_list = simple_tl.normalize(mean_list, var_list, dim=dim)
@@ -1476,19 +1476,19 @@ def test_normalize(big_tl: TensorList, dim):
  # assert torch.allclose(global_mean_rescaled, torch.tensor(avg_mean), rtol=1e-1, atol=1e-1) # Loose check
  # assert torch.allclose(global_var_rescaled, torch.tensor(avg_var), rtol=1e-1, atol=1e-1) # Loose check
  else:
- assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-4)
- assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-4)
+ assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-3)
+ assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-3)

  # Z-normalize helper
  znorm = simple_tl.znormalize(dim=dim, eps=1e-10)
  znorm_mean = znorm.mean(dim=dim if dim != 'global' else None)
  znorm_var = znorm.var(dim=dim if dim != 'global' else None)
  if dim == 'global':
- assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-4)
- assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-4)
+ assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-3)
+ assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-3)
  else:
- assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-4)
- assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-4)
+ assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-3)
+ assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-3)


  # Test inplace
@@ -38,7 +38,7 @@ class Preconditioner(Transform):


  def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
- step = self.global_state.get('step', 0)
+ step = self.global_state.get('__step', 0)
  states = [self.state[p] for p in params]
  settings = [self.settings[p] for p in params]
  global_settings = settings[0]
@@ -47,8 +47,10 @@ class Preconditioner(Transform):
  scale_first = global_settings['__scale_first']
  scale_factor = 0
  if scale_first and step == 0:
- # initial step size guess from pytorch LBFGS
- scale_factor = TensorList(tensors).abs().sum()
+ # initial step size guess from pytorch LBFGS was too unstable
+ # I switched to norm
+ tensors = TensorList(tensors)
+ scale_factor = tensors.abs().global_mean().clip(min=1)

  # update preconditioner
  if step % update_freq == 0:
@@ -65,11 +67,11 @@ class Preconditioner(Transform):
  if scale_first and step == 0:
  torch._foreach_div_(tensors, scale_factor)

- self.global_state['step'] = step + 1
+ self.global_state['__step'] = step + 1
  return tensors

  def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
- step = self.global_state.get('step', 0)
+ step = self.global_state.get('__step', 0)
  tensors_vec = torch.cat([t.ravel() for t in tensors])
  params_vec = torch.cat([p.ravel() for p in params])
  grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
@@ -82,8 +84,8 @@ class Preconditioner(Transform):
  scale_first = global_settings['__scale_first']
  scale_factor = 0
  if scale_first and step == 0:
- # initial step size guess from pytorch LBFGS
- scale_factor = tensors_vec.abs().sum()
+ # initial step size guess from pytorch LBFGS was too unstable
+ scale_factor = tensors_vec.abs().mean().clip(min=1)

  # update preconditioner
  if step % update_freq == 0:
@@ -99,11 +101,10 @@ class Preconditioner(Transform):

  # scale initial step, when preconditioner might not have been applied
  if scale_first and step == 0:
- if scale_factor >= torch.finfo(tensors_vec.dtype).eps:
- tensors_vec /= scale_factor
+ tensors_vec /= scale_factor

  tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
- self.global_state['step'] = step + 1
+ self.global_state['__step'] = step + 1
  return tensors

  @torch.no_grad
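Note: the preconditioner hunks above swap the first-step scale factor from the summed magnitude (the PyTorch L-BFGS heuristic) to the mean magnitude clipped at 1. A standalone plain-torch sketch of why that matters for large parameter counts (numbers are illustrative, not taken from the package):

    import torch

    g = torch.randn(1_000_000) * 1e-3      # stand-in for the first update

    old_scale = g.abs().sum()               # grows with the number of elements
    new_scale = g.abs().mean().clip(min=1)  # per-element magnitude, never below 1

    print((g / old_scale).abs().mean())     # ~1e-6: the first step nearly vanishes
    print((g / new_scale).abs().mean())     # ~8e-4: the first step keeps its scale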
@@ -3,7 +3,7 @@ from .adadam import Adadam
  from .adamY import AdamY
  from .adasoap import AdaSOAP
  from .curveball import CurveBall
- from .dsoap import DSOAP
+ from .soapy import SOAPY
  from .gradmin import GradMin
  from .reduce_outward_lr import ReduceOutwardLR
  from .spectral import SpectralPreconditioner
@@ -11,4 +11,5 @@ from .subspace_preconditioners import (
  HistorySubspacePreconditioning,
  RandomSubspacePreconditioning,
  )
- from .tropical_newton import TropicalNewton
+ from .tropical_newton import TropicalNewton
+ from .newton_solver import NewtonSolver
@@ -140,11 +140,17 @@ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
  class ABSOAP(Transform):
  """SOAP but with two extra letters included in its name in order to improve converence

+ so what you can do is choose what goes into what ,and that is supposed to be good.
+
  new args

  scale by s whether to scale gradient differences by parameter differences

  y_to_ema2 whether to use gradient differences for exponential moving average too
+
+ okay I changed these args into another ones
+
+ BASICALLY THIS IS FOR MY EXPERIMENTS
  """
  def __init__(
  self,
@@ -213,7 +219,7 @@ class ABSOAP(Transform):
  if 'g_prev' not in state:
  state['p_prev'] = p.clone()
  state['g_prev'] = t.clone()
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue

  p_prev = state['p_prev']
@@ -285,7 +291,7 @@ class ABSOAP(Transform):
  state['Q'] = get_orthogonal_matrix(state['GG'])

  state['step'] = 0
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

@@ -50,7 +50,7 @@ def adadam_(
  return None

  class Adadam(Module):
- """Adam with a diagonally preconditioned preconditioner and a graceful name."""
+ """Adam with a diagonally preconditioned preconditioner."""
  def __init__(
  self,
  beta1: float = 0.9,
@@ -37,7 +37,7 @@ def adamy_(
  p_prev.copy_(p)
  g_prev.copy_(g)

- update = g.sign().lazy_mul_(alpha*0.1)
+ update = g.clip(-0.1,0.1).lazy_mul_(alpha)
  if params_ is None: return update
  params_.sub_(update)
  return None
@@ -218,7 +218,7 @@ class AdaSOAP(Transform):
  state['Q'] = get_orthogonal_matrix(GG_precond)

  state['step'] = 0
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

@@ -71,7 +71,7 @@ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemir


  class AlgebraicNewton(Module):
- """newton in other algebras, not practical because solving linear system is very hard."""
+ """newton in other algebras, not that it works."""
  def __init__(
  self,
  reg: float | None = None,
@@ -13,7 +13,7 @@ def curveball(
  momentum: float | NumberList,
  precond_lr: float | NumberList,
  ):
- """returns z_, clone it!!!"""
+ """returns z_, clone it!!! (no just negate it)"""
  delta = Hz + tensors
  z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
  return z_
@@ -14,7 +14,7 @@ from ..smoothing.gaussian import Reformulation


  class GradMin(Reformulation):
- """Reformulates the objective to minimize sum of gradient magnitudes via autograd.
+ """Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.

  Args:
  loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
@@ -0,0 +1,88 @@
+ from collections.abc import Callable, Iterable
+ from typing import Any, Literal, overload
+
+ import torch
+
+ from ...core import Chainable, Module, apply, Modular
+ from ...utils import TensorList, as_tensorlist
+ from ...utils.derivatives import hvp
+ from ..quasi_newton import LBFGS
+
+ class NewtonSolver(Module):
+ """Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
+ def __init__(
+ self,
+ solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
+ maxiter=None,
+ tol=1e-3,
+ reg: float = 0,
+ warm_start=True,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+ super().__init__(defaults,)
+
+ if inner is not None:
+ self.set_child('inner', inner)
+
+ @torch.no_grad
+ def step(self, vars):
+ params = TensorList(vars.params)
+ closure = vars.closure
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+ settings = self.settings[params[0]]
+ solver_cls = settings['solver']
+ maxiter = settings['maxiter']
+ tol = settings['tol']
+ reg = settings['reg']
+ warm_start = settings['warm_start']
+
+ # ---------------------- Hessian vector product function --------------------- #
+ grad = vars.get_grad(create_graph=True)
+
+ def H_mm(x):
+ with torch.enable_grad():
+ Hvp = TensorList(hvp(params, grad, x, create_graph=True))
+ if reg != 0: Hvp = Hvp + (x*reg)
+ return Hvp
+
+ # -------------------------------- inner step -------------------------------- #
+ b = as_tensorlist(grad)
+ if 'inner' in self.children:
+ b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+
+ # ---------------------------------- run cg ---------------------------------- #
+ x0 = None
+ if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+ if x0 is None: x = b.zeros_like().requires_grad_(True)
+ else: x = x0.clone().requires_grad_(True)
+
+ solver = solver_cls(x)
+ def lstsq_closure(backward=True):
+ Hx = H_mm(x)
+ loss = (Hx-b).pow(2).global_mean()
+ if backward:
+ solver.zero_grad()
+ loss.backward(inputs=x)
+ return loss
+
+ if maxiter is None: maxiter = b.global_numel()
+ loss = None
+ initial_loss = lstsq_closure(False)
+ if initial_loss > tol:
+ for i in range(maxiter):
+ loss = solver.step(lstsq_closure)
+ assert loss is not None
+ if min(loss, loss/initial_loss) < tol: break
+
+ print(f'{loss = }')
+
+ if warm_start:
+ assert x0 is not None
+ x0.copy_(x)
+
+ vars.update = x.detach()
+ return vars
+
+
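Note: the new module solves the Newton system H x = b without forming H: it minimizes the least-squares residual of Hessian-vector products with a user-supplied inner optimizer (an LBFGS-based Modular by default). A self-contained plain-torch sketch of that same idea, independent of the torchzero API (all names and sizes below are made up for illustration):

    import torch

    # toy SPD operator standing in for Hessian-vector products
    n = 50
    A = torch.randn(n, n)
    H = lambda v: (A.T @ A) @ v + 0.1 * v
    b = torch.randn(n)

    x = torch.zeros(n, requires_grad=True)
    opt = torch.optim.LBFGS([x], max_iter=100)

    def lstsq_closure():
        opt.zero_grad()
        loss = (H(x) - b).pow(2).mean()   # same objective as lstsq_closure above
        loss.backward()
        return loss

    opt.step(lstsq_closure)
    print((H(x.detach()) - b).norm())     # residual should be small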
@@ -3,7 +3,7 @@ from operator import itemgetter
  import torch

  from ...core import Chainable, Transform, apply
- from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+ from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims

  @torch.no_grad
  def update_soap_covariances_(
@@ -135,7 +135,7 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N

  return final, exp_avg_sq

- class DSOAP(Transform):
+ class SOAPY(Transform):
  """SOAP but uses scaled gradient differences

  new args
@@ -195,7 +195,7 @@ class DSOAP(Transform):
  if 'g_prev' not in state:
  state['p_prev'] = p.clone()
  state['g_prev'] = t.clone()
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue

  p_prev = state['p_prev']
@@ -228,7 +228,7 @@ class DSOAP(Transform):
  state['Q'] = get_orthogonal_matrix(state['GG'])

  state['step'] = 0
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

@@ -194,8 +194,10 @@ class SpectralPreconditioner(TensorwisePreconditioner):
  order (int, optional):
  whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
  solver (str, optional): what to use for whitening. Defaults to 'svd'.
- U_beta (float | None, optional): beta for U (probably a bad idea). Defaults to None.
- S_beta (float | None, optional): beta for S (probably a bad idea). Defaults to None.
+ A_beta (float | None, optional):
+ beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
+ B_beta (float | None, optional):
+ beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
  interval (int, optional): How often to update history. Defaults to 1 (every step).
  concat_params (bool, optional):
  whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
@@ -275,7 +277,7 @@ class SpectralPreconditioner(TensorwisePreconditioner):
  A = state.get('A', None)
  if A is None:
  # make a conservative step to avoid issues due to different GD scaling
- return tensor.div_(max(1, tensor.abs().sum())) # pyright:ignore[reportArgumentType]
+ return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

  B = state['B']
  update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
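Note: several hunks in this release (ABSOAP, AdaSOAP, SOAPY above, the spectral preconditioner here, and SOAP and the subspace preconditioners below) replace the same conservative first-step or fallback update: instead of a sign-based or sum-normalized step, the raw update is clipped to [-0.1, 0.1]. The difference on a toy tensor (plain torch, for illustration only):

    import torch

    g = torch.tensor([3.0, -0.02, 0.5])

    old_fallback = g.sign().div_(10)   # every entry becomes exactly +/-0.1
    new_fallback = g.clip(-0.1, 0.1)   # small entries pass through, large ones saturate

    print(old_fallback)                # tensor([ 0.1000, -0.1000,  0.1000])
    print(new_fallback)                # tensor([ 0.1000, -0.0200,  0.1000])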
@@ -38,16 +38,19 @@ def apply_subspace_preconditioner(
  return basis @ update_projected # d

  class RandomSubspacePreconditioning(Transform):
- """full matrix rmsprop in random subspace"""
- def __init__(self, k: int, beta: float | None = 0.99):
- defaults = dict(k=k, beta=beta)
+ """full matrix rmsprop in random slowly changing subspace"""
+ def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
+ defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
  super().__init__(defaults, uses_grad=False)

+ if inner is not None: self.set_child('inner', inner)
+
  def transform(self, tensors, params, grads, vars):
  settings = self.settings[params[0]]
  g = torch.cat([t.view(-1) for t in tensors])
  k = settings['k']
  beta = settings['beta']
+ basis_beta = settings['basis_beta']

  if 'basis' not in self.global_state:
  self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
@@ -56,13 +59,19 @@ class RandomSubspacePreconditioning(Transform):
  basis = self.global_state['basis']
  accumulator = self.global_state['accumulator']

+ if basis_beta is not None:
+ basis.lerp_(torch.randn_like(basis), 1-basis_beta)
+
  update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+ if 'inner' in self.children:
+ tensors = apply(self.children['inner'], tensors, params, grads, vars)
+ g = torch.cat([t.view(-1) for t in tensors])
+
  try:
  preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
  except torch.linalg.LinAlgError:
- denom = g.abs().sum()
- if denom <= 1e-10: denom = torch.ones_like(denom)
- preconditioned = g / g.abs().sum()
+ preconditioned = g.clip(-0.1, 0.1)
  vec_to_tensors_(preconditioned, tensors)

  return tensors
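Note: the new basis_beta option makes the random projection basis drift instead of staying fixed; each step it is nudged a fraction (1 - basis_beta) toward fresh Gaussian noise. A minimal standalone illustration with plain torch (sizes are made up):

    import torch

    n, k = 1000, 16
    basis = torch.randn(n, k)
    basis_beta = 0.99

    for step in range(100):
        # 1% of the basis is replaced by fresh noise each step,
        # so the k-dimensional subspace changes slowly over time
        basis.lerp_(torch.randn_like(basis), 1 - basis_beta)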
@@ -119,9 +128,7 @@ class HistorySubspacePreconditioning(Transform):
  try:
  preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
  except torch.linalg.LinAlgError:
- denom = g.abs().sum()
- if denom <= 1e-10: denom = torch.ones_like(denom)
- preconditioned = g / g.abs().sum()
+ preconditioned = g.clip(-0.1,0.1)
  vec_to_tensors_(preconditioned, tensors)

  return tensors
@@ -222,8 +222,7 @@ class SOAP(Transform):
  state['Q'] = get_orthogonal_matrix(state['GG'])

  state['step'] = 0
- updates.append(tensors[i].sign().div_(10))
- # updates.append(tensors[i] / tensors[i].abs().sum())
+ updates.append(tensors[i].clip(-0.1, 0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
  # I use scaled update instead as to not mess up with next modules.

@@ -1,4 +1,5 @@
  import math
+ from functools import partial
  from abc import ABC, abstractmethod
  from collections.abc import Iterable
  from typing import Any, Literal
@@ -33,6 +34,25 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",

  return projected_closure

+ def _projected_get_grad_override(
+ retain_graph: bool | None = None,
+ create_graph: bool = False,
+ projection: Any = ...,
+ unprojected_vars: Any = ...,
+ self: Any = ...,
+ ):
+ assert isinstance(projection, Projection)
+ assert isinstance(unprojected_vars, Vars)
+ assert isinstance(self, Vars)
+
+ if self.grad is not None: return self.grad
+ grads = unprojected_vars.get_grad(retain_graph, create_graph)
+ projected_grads = list(projection.project(grads, self, current='grads'))
+ self.grad = projected_grads
+ for p, g in zip(self.params, projected_grads):
+ p.grad = g
+ return self.grad
+

  class Projection(Module, ABC):
  """
@@ -137,6 +157,12 @@ class Projection(Module, ABC):

  # step
  projected_vars.params = self._projected_params
+ projected_vars.get_grad = partial(
+ _projected_get_grad_override,
+ projection=self,
+ unprojected_vars=vars,
+ self=projected_vars,
+ )
  projected_vars = self.children['modules'].step(projected_vars)

  # empty fake params storage
@@ -149,7 +175,7 @@ class Projection(Module, ABC):
  unprojected_vars = projected_vars.clone(clone_update=False)
  unprojected_vars.closure = vars.closure
  unprojected_vars.params = vars.params
- if unprojected_vars.grad is None: unprojected_vars.grad = vars.grad
+ unprojected_vars.grad = vars.grad

  if self._project_update:
  assert projected_vars.update is not None
@@ -37,10 +37,11 @@ def lbfgs(
  z_tfm: Any,
  ):
  if len(s_history) == 0 or y_k is None or ys_k is None:
- # dir = params.grad.sign() # may work fine

- # initial step size guess taken from pytorch L-BFGS
- return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]
+ # initial step size guess modified from pytorch L-BFGS
+ scale = 1 / tensors_.abs().global_sum()
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+ return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]

  else:
  # 1st loop
@@ -36,10 +36,11 @@ def lbfgs(
  step: int,
  ):
  if len(s_history) == 0 or y_k is None or ys_k is None:
- # dir = params.grad.sign() # may work fine

- # initial step size guess taken from pytorch L-BFGS
- return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]
+ # initial step size guess modified from pytorch L-BFGS
+ scale = 1 / tensors_.abs().global_sum()
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+ return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]

  else:
  # 1st loop
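Note: the modified initial-step guess in both L-BFGS variants (and in LSR1 below) keeps the 1/sum(|g|) heuristic but falls back to 1/mean(|g|) when the sum is so large that the scale would fall below the 1e-5 threshold. A standalone numeric sketch with plain torch (the TensorList API is not used here; names are hypothetical):

    import torch

    def initial_scale(g: torch.Tensor) -> torch.Tensor:
        scale = 1 / g.abs().sum()          # original PyTorch L-BFGS style guess
        if scale < 1e-5:                   # huge models: sum(|g|) explodes, step vanishes
            scale = 1 / g.abs().mean()     # fall back to a per-element magnitude
        return g * min(1.0, float(scale))  # never scale the step up

    g = torch.randn(10_000_000)            # sum(|g|) is ~8e6, so 1/sum is far below 1e-5
    print(initial_scale(g).abs().mean())   # ~0.8 instead of ~1e-7 under the old rule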
@@ -17,8 +17,9 @@ def lsr1_(
  ):
  if step == 0 or not s_history:
  # initial step size guess from pytorch
- tensors_.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
- return tensors_
+ scale = 1 / tensors_.abs().global_sum()
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+ return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]

  m = len(s_history)

@@ -64,7 +65,9 @@ def lsr1_(
  Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]

  if scale_second and step == 1:
- Hx.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
+ scale = 1 / tensors_.abs().global_sum()
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+ Hx.mul_(min(1.0, scale)) # pyright:ignore[reportArgumentType]
  return Hx
