torchzero 0.3.6__tar.gz → 0.3.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. {torchzero-0.3.6 → torchzero-0.3.9}/PKG-INFO +2 -2
  2. {torchzero-0.3.6 → torchzero-0.3.9}/README.md +1 -1
  3. {torchzero-0.3.6 → torchzero-0.3.9}/pyproject.toml +2 -2
  4. {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_opts.py +1 -1
  5. {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_tensorlist.py +1 -1
  6. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/preconditioner.py +12 -11
  7. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/__init__.py +3 -2
  8. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/absoap.py +8 -2
  9. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/adadam.py +1 -1
  10. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/adamY.py +1 -1
  11. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/adasoap.py +2 -2
  12. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/algebraic_newton.py +1 -1
  13. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/curveball.py +1 -1
  14. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/gradmin.py +1 -1
  15. torchzero-0.3.9/torchzero/modules/experimental/newton_solver.py +88 -0
  16. torchzero-0.3.6/torchzero/modules/experimental/dsoap.py → torchzero-0.3.9/torchzero/modules/experimental/soapy.py +4 -4
  17. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/spectral.py +5 -3
  18. torchzero-0.3.9/torchzero/modules/experimental/structured_newton.py +111 -0
  19. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/subspace_preconditioners.py +16 -9
  20. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/soap.py +1 -2
  21. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/projection.py +27 -1
  22. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/cg.py +9 -9
  23. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +4 -3
  24. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lbfgs.py +4 -3
  25. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lsr1.py +7 -3
  26. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/quasi_newton.py +18 -17
  27. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/second_order/__init__.py +1 -1
  28. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/second_order/newton.py +11 -6
  29. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/second_order/newton_cg.py +3 -3
  30. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/second_order/nystrom.py +6 -6
  31. torchzero-0.3.9/torchzero/utils/linalg/benchmark.py +20 -0
  32. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/solve.py +15 -14
  33. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/PKG-INFO +2 -2
  34. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/SOURCES.txt +4 -1
  35. {torchzero-0.3.6 → torchzero-0.3.9}/LICENSE +0 -0
  36. {torchzero-0.3.6 → torchzero-0.3.9}/docs/source/conf.py +0 -0
  37. {torchzero-0.3.6 → torchzero-0.3.9}/setup.cfg +0 -0
  38. {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_identical.py +0 -0
  39. {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_module.py +0 -0
  40. {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_utils_optimizer.py +0 -0
  41. {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_vars.py +0 -0
  42. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/__init__.py +0 -0
  43. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/__init__.py +0 -0
  44. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/module.py +0 -0
  45. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/transform.py +0 -0
  46. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/__init__.py +0 -0
  47. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/clipping/__init__.py +0 -0
  48. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/clipping/clipping.py +0 -0
  49. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/clipping/ema_clipping.py +0 -0
  50. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/clipping/growth_clipping.py +0 -0
  51. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
  52. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/tropical_newton.py +0 -0
  53. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/functional.py +0 -0
  54. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/__init__.py +0 -0
  55. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/fdm.py +0 -0
  56. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
  57. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
  58. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/rfdm.py +0 -0
  59. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/__init__.py +0 -0
  60. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/backtracking.py +0 -0
  61. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/line_search.py +0 -0
  62. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/scipy.py +0 -0
  63. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/strong_wolfe.py +0 -0
  64. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/trust_region.py +0 -0
  65. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/lr/__init__.py +0 -0
  66. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/lr/lr.py +0 -0
  67. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/lr/step_size.py +0 -0
  68. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/__init__.py +0 -0
  69. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/averaging.py +0 -0
  70. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/cautious.py +0 -0
  71. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/ema.py +0 -0
  72. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/experimental.py +0 -0
  73. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/matrix_momentum.py +0 -0
  74. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/momentum.py +0 -0
  75. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/__init__.py +0 -0
  76. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/accumulate.py +0 -0
  77. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/binary.py +0 -0
  78. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/debug.py +0 -0
  79. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/misc.py +0 -0
  80. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/multi.py +0 -0
  81. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/reduce.py +0 -0
  82. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/split.py +0 -0
  83. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/switch.py +0 -0
  84. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/unary.py +0 -0
  85. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/utility.py +0 -0
  86. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/__init__.py +0 -0
  87. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/adagrad.py +0 -0
  88. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/adam.py +0 -0
  89. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/lion.py +0 -0
  90. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/muon.py +0 -0
  91. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/orthograd.py +0 -0
  92. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/rmsprop.py +0 -0
  93. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/rprop.py +0 -0
  94. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/shampoo.py +0 -0
  95. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/sophia_h.py +0 -0
  96. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/__init__.py +0 -0
  97. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/dct.py +0 -0
  98. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/fft.py +0 -0
  99. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/galore.py +0 -0
  100. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/structural.py +0 -0
  101. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/__init__.py +0 -0
  102. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
  103. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/olbfgs.py +0 -0
  104. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/smoothing/__init__.py +0 -0
  105. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/smoothing/gaussian.py +0 -0
  106. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/smoothing/laplacian.py +0 -0
  107. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/weight_decay/__init__.py +0 -0
  108. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/weight_decay/weight_decay.py +0 -0
  109. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/wrappers/__init__.py +0 -0
  110. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
  111. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/__init__.py +0 -0
  112. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/utility/__init__.py +0 -0
  113. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/utility/split.py +0 -0
  114. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/wrappers/__init__.py +0 -0
  115. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/wrappers/nevergrad.py +0 -0
  116. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/wrappers/nlopt.py +0 -0
  117. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/wrappers/scipy.py +0 -0
  118. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/__init__.py +0 -0
  119. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/compile.py +0 -0
  120. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/derivatives.py +0 -0
  121. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/__init__.py +0 -0
  122. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/matrix_funcs.py +0 -0
  123. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/orthogonalize.py +0 -0
  124. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/qr.py +0 -0
  125. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/svd.py +0 -0
  126. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/numberlist.py +0 -0
  127. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/ops.py +0 -0
  128. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/optimizer.py +0 -0
  129. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/optuna_tools.py +0 -0
  130. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/params.py +0 -0
  131. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/python_tools.py +0 -0
  132. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/tensorlist.py +0 -0
  133. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/torch_tools.py +0 -0
  134. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/dependency_links.txt +0 -0
  135. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/requires.txt +0 -0
  136. {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.6
+ Version: 0.3.9
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  License: MIT License
@@ -156,7 +156,7 @@ for epoch in range(100):
  * `Newton`: Classic Newton's method.
  * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
- * `NystromPCG`: NewtonCG with Nyström preconditioning (my current recommendation).
+ * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).

  * **Quasi-Newton**: Approximate second-order optimization methods.
  * `LBFGS`: Limited-memory BFGS.
@@ -117,7 +117,7 @@ for epoch in range(100):
  * `Newton`: Classic Newton's method.
  * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
- * `NystromPCG`: NewtonCG with Nyström preconditioning (my current recommendation).
+ * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).

  * **Quasi-Newton**: Approximate second-order optimization methods.
  * `LBFGS`: Limited-memory BFGS.
@@ -2,7 +2,7 @@
  # STEP 1 - COMMIT NEW CHANGES BUT DON'T PUSH THEM YET
  # STEP 2 - BUMP VERSION AND COMMIT IT (DONT PUSH!!!!)
  # STEP 3 - CREATE TAG WITH THAT VERSION
- # STEP 4 - PUSH CHANGES
+ # STEP 4 - PUSH (SYNC) CHANGES
  # STEP 5 - PUSH TAG

  [build-system]
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
  name = "torchzero"
  description = "Modular optimization library for PyTorch."

- version = "0.3.6"
+ version = "0.3.9"
  dependencies = [
  "torch",
  "numpy",
@@ -745,7 +745,7 @@ SSVM = Run(
  func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
  sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
  needs_closure=True,
- func='rosen', steps=50, loss=1e-12, merge_invariant=True,
+ func='rosen', steps=50, loss=1e-10, merge_invariant=True,
  sphere_steps=10, sphere_loss=0,
  )

@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho
  expected = vec_equiv_func()

  if isinstance(result, bool): assert result == expected
- else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
+ else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"


  def test_global_vector_norm(simple_tl: TensorList):
@@ -38,17 +38,18 @@ class Preconditioner(Transform):


  def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
- step = self.global_state.get('step', 0)
+ step = self.global_state.get('__step', 0)
  states = [self.state[p] for p in params]
  settings = [self.settings[p] for p in params]
  global_settings = settings[0]
  update_freq = global_settings['__update_freq']

  scale_first = global_settings['__scale_first']
- scale_factor = 0
+ scale_factor = 1
  if scale_first and step == 0:
  # initial step size guess from pytorch LBFGS
- scale_factor = TensorList(tensors).abs().sum()
+ scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
+ scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)

  # update preconditioner
  if step % update_freq == 0:
@@ -63,13 +64,13 @@

  # scale initial step, when preconditioner might not have been applied
  if scale_first and step == 0:
- torch._foreach_div_(tensors, scale_factor)
+ torch._foreach_mul_(tensors, scale_factor)

- self.global_state['step'] = step + 1
+ self.global_state['__step'] = step + 1
  return tensors

  def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
- step = self.global_state.get('step', 0)
+ step = self.global_state.get('__step', 0)
  tensors_vec = torch.cat([t.ravel() for t in tensors])
  params_vec = torch.cat([p.ravel() for p in params])
  grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
@@ -80,10 +81,11 @@ class Preconditioner(Transform):
  update_freq = global_settings['__update_freq']

  scale_first = global_settings['__scale_first']
- scale_factor = 0
+ scale_factor = 1
  if scale_first and step == 0:
  # initial step size guess from pytorch LBFGS
- scale_factor = tensors_vec.abs().sum()
+ scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
+ scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)

  # update preconditioner
  if step % update_freq == 0:
@@ -99,11 +101,10 @@

  # scale initial step, when preconditioner might not have been applied
  if scale_first and step == 0:
- if scale_factor >= torch.finfo(tensors_vec.dtype).eps:
- tensors_vec /= scale_factor
+ tensors_vec *= scale_factor

  tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
- self.global_state['step'] = step + 1
+ self.global_state['__step'] = step + 1
  return tensors

  @torch.no_grad
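The four hunks above rework Preconditioner's first-step heuristic: rather than dividing the update by the (possibly zero) sum of absolute gradient values, the update is now multiplied by a reciprocal scale that is capped at 1 and floored at machine epsilon. A standalone sketch of that scaling rule, not taken from the package:

import torch

def first_step_scale(tensors: list[torch.Tensor]) -> torch.Tensor:
    # initial step size guess in the style of PyTorch LBFGS
    total = sum(t.abs().sum() for t in tensors)                # sum of |g| over all tensors
    scale = 1.0 / total.clip(min=1)                            # never scale a small gradient up
    return scale.clip(min=torch.finfo(tensors[0].dtype).eps)   # and never multiply by exactly zero

g = [torch.full((3,), 10.0), torch.full((2,), 5.0)]            # sum of |g| is 40
scaled = [t * first_step_scale(g) for t in g]                  # equivalent to t / 40 here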
@@ -3,7 +3,7 @@ from .adadam import Adadam
  from .adamY import AdamY
  from .adasoap import AdaSOAP
  from .curveball import CurveBall
- from .dsoap import DSOAP
+ from .soapy import SOAPY
  from .gradmin import GradMin
  from .reduce_outward_lr import ReduceOutwardLR
  from .spectral import SpectralPreconditioner
@@ -11,4 +11,5 @@ from .subspace_preconditioners import (
  HistorySubspacePreconditioning,
  RandomSubspacePreconditioning,
  )
- from .tropical_newton import TropicalNewton
+ from .tropical_newton import TropicalNewton
+ from .newton_solver import NewtonSolver
@@ -140,11 +140,17 @@ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
  class ABSOAP(Transform):
  """SOAP but with two extra letters included in its name in order to improve converence

+ so what you can do is choose what goes into what ,and that is supposed to be good.
+
  new args

  scale by s whether to scale gradient differences by parameter differences

  y_to_ema2 whether to use gradient differences for exponential moving average too
+
+ okay I changed these args into another ones
+
+ BASICALLY THIS IS FOR MY EXPERIMENTS
  """
  def __init__(
  self,
@@ -213,7 +219,7 @@ class ABSOAP(Transform):
  if 'g_prev' not in state:
  state['p_prev'] = p.clone()
  state['g_prev'] = t.clone()
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue

  p_prev = state['p_prev']
@@ -285,7 +291,7 @@ class ABSOAP(Transform):
  state['Q'] = get_orthogonal_matrix(state['GG'])

  state['step'] = 0
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

@@ -50,7 +50,7 @@ def adadam_(
  return None

  class Adadam(Module):
- """Adam with a diagonally preconditioned preconditioner and a graceful name."""
+ """Adam with a diagonally preconditioned preconditioner."""
  def __init__(
  self,
  beta1: float = 0.9,
@@ -37,7 +37,7 @@ def adamy_(
  p_prev.copy_(p)
  g_prev.copy_(g)

- update = g.sign().lazy_mul_(alpha*0.1)
+ update = g.clip(-0.1,0.1).lazy_mul_(alpha)
  if params_ is None: return update
  params_.sub_(update)
  return None
@@ -218,9 +218,9 @@ class AdaSOAP(Transform):
  state['Q'] = get_orthogonal_matrix(GG_precond)

  state['step'] = 0
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
- # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
+ # that can mess with other modules scaling

  # Projecting gradients to the eigenbases of Shampoo's preconditioner
  # i.e. projecting to the eigenbases of matrices in state['GG']
@@ -71,7 +71,7 @@ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemir


  class AlgebraicNewton(Module):
- """newton in other algebras, not practical because solving linear system is very hard."""
+ """newton in other algebras, not that it works."""
  def __init__(
  self,
  reg: float | None = None,
@@ -13,7 +13,7 @@ def curveball(
  momentum: float | NumberList,
  precond_lr: float | NumberList,
  ):
- """returns z_, clone it!!!"""
+ """returns z_, clone it!!! (no just negate it)"""
  delta = Hz + tensors
  z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
  return z_
@@ -14,7 +14,7 @@ from ..smoothing.gaussian import Reformulation


  class GradMin(Reformulation):
- """Reformulates the objective to minimize sum of gradient magnitudes via autograd.
+ """Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.

  Args:
  loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
@@ -0,0 +1,88 @@
+ from collections.abc import Callable, Iterable
+ from typing import Any, Literal, overload
+
+ import torch
+
+ from ...core import Chainable, Module, apply, Modular
+ from ...utils import TensorList, as_tensorlist
+ from ...utils.derivatives import hvp
+ from ..quasi_newton import LBFGS
+
+ class NewtonSolver(Module):
+ """Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
+ def __init__(
+ self,
+ solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
+ maxiter=None,
+ tol=1e-3,
+ reg: float = 0,
+ warm_start=True,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+ super().__init__(defaults,)
+
+ if inner is not None:
+ self.set_child('inner', inner)
+
+ @torch.no_grad
+ def step(self, vars):
+ params = TensorList(vars.params)
+ closure = vars.closure
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+ settings = self.settings[params[0]]
+ solver_cls = settings['solver']
+ maxiter = settings['maxiter']
+ tol = settings['tol']
+ reg = settings['reg']
+ warm_start = settings['warm_start']
+
+ # ---------------------- Hessian vector product function --------------------- #
+ grad = vars.get_grad(create_graph=True)
+
+ def H_mm(x):
+ with torch.enable_grad():
+ Hvp = TensorList(hvp(params, grad, x, create_graph=True))
+ if reg != 0: Hvp = Hvp + (x*reg)
+ return Hvp
+
+ # -------------------------------- inner step -------------------------------- #
+ b = as_tensorlist(grad)
+ if 'inner' in self.children:
+ b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+
+ # ---------------------------------- run cg ---------------------------------- #
+ x0 = None
+ if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+ if x0 is None: x = b.zeros_like().requires_grad_(True)
+ else: x = x0.clone().requires_grad_(True)
+
+ solver = solver_cls(x)
+ def lstsq_closure(backward=True):
+ Hx = H_mm(x)
+ loss = (Hx-b).pow(2).global_mean()
+ if backward:
+ solver.zero_grad()
+ loss.backward(inputs=x)
+ return loss
+
+ if maxiter is None: maxiter = b.global_numel()
+ loss = None
+ initial_loss = lstsq_closure(False)
+ if initial_loss > tol:
+ for i in range(maxiter):
+ loss = solver.step(lstsq_closure)
+ assert loss is not None
+ if min(loss, loss/initial_loss) < tol: break
+
+ print(f'{loss = }')
+
+ if warm_start:
+ assert x0 is not None
+ x0.copy_(x)
+
+ vars.update = x.detach()
+ return vars
+
+
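The new NewtonSolver module recasts the Newton step as a least-squares problem: it minimizes ||Hx - b||^2 with an inner optimizer (a modular LBFGS by default), touching the Hessian only through Hessian-vector products. A rough standalone sketch of the same idea in plain PyTorch, using torch.optim.LBFGS in place of the package's own solver machinery:

import torch

def newton_direction_lstsq(loss_fn, params, max_iter=50):
    # gradient with a graph attached so Hessian-vector products are available
    loss = loss_fn()
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # unknown Newton direction x, optimized to minimize ||Hx - g||^2
    x = [torch.zeros_like(p, requires_grad=True) for p in params]
    opt = torch.optim.LBFGS(x, max_iter=max_iter)

    def residual():
        opt.zero_grad()
        Hx = torch.autograd.grad(grads, params, grad_outputs=x,
                                 retain_graph=True, create_graph=True)
        r = sum((hx - g.detach()).pow(2).sum() for hx, g in zip(Hx, grads))
        r.backward(inputs=x)
        return r

    opt.step(residual)
    return [xi.detach() for xi in x]

On a convex quadratic this converges toward the Newton direction up to the least-squares tolerance; as the docstring itself notes, NewtonCG or NystromPCG is usually the better choice.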
@@ -3,7 +3,7 @@ from operator import itemgetter
  import torch

  from ...core import Chainable, Transform, apply
- from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+ from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims

  @torch.no_grad
  def update_soap_covariances_(
@@ -135,7 +135,7 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N

  return final, exp_avg_sq

- class DSOAP(Transform):
+ class SOAPY(Transform):
  """SOAP but uses scaled gradient differences

  new args
@@ -195,7 +195,7 @@ class DSOAP(Transform):
  if 'g_prev' not in state:
  state['p_prev'] = p.clone()
  state['g_prev'] = t.clone()
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue

  p_prev = state['p_prev']
@@ -228,7 +228,7 @@ class DSOAP(Transform):
  state['Q'] = get_orthogonal_matrix(state['GG'])

  state['step'] = 0
- updates.append(tensors[i].sign())
+ updates.append(tensors[i].clip(-0.1,0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

@@ -194,8 +194,10 @@ class SpectralPreconditioner(TensorwisePreconditioner):
  order (int, optional):
  whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
  solver (str, optional): what to use for whitening. Defaults to 'svd'.
- U_beta (float | None, optional): beta for U (probably a bad idea). Defaults to None.
- S_beta (float | None, optional): beta for S (probably a bad idea). Defaults to None.
+ A_beta (float | None, optional):
+ beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
+ B_beta (float | None, optional):
+ beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
  interval (int, optional): How often to update history. Defaults to 1 (every step).
  concat_params (bool, optional):
  whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
@@ -275,7 +277,7 @@ class SpectralPreconditioner(TensorwisePreconditioner):
  A = state.get('A', None)
  if A is None:
  # make a conservative step to avoid issues due to different GD scaling
- return tensor.div_(max(1, tensor.abs().sum())) # pyright:ignore[reportArgumentType]
+ return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

  B = state['B']
  update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
@@ -0,0 +1,111 @@
+ # idea https://arxiv.org/pdf/2212.09841
+ import warnings
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, apply
+ from ...utils import TensorList, vec_to_tensors
+ from ...utils.derivatives import (
+ hessian_list_to_mat,
+ hessian_mat,
+ hvp,
+ hvp_fd_central,
+ hvp_fd_forward,
+ jacobian_and_hessian_wrt,
+ )
+
+
+ class StructuredNewton(Module):
+ """TODO
+ Args:
+ structure (str, optional): structure.
+ reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
+ hvp_method (str):
+ how to calculate hvp_method. Defaults to "autograd".
+ inner (Chainable | None, optional): inner modules. Defaults to None.
+
+ """
+ def __init__(
+ self,
+ structure: Literal[
+ "diagonal",
+ "diagonal1",
+ "diagonal_abs",
+ "tridiagonal",
+ "circulant",
+ "toeplitz",
+ "toeplitz_like",
+ "hankel",
+ "rank1",
+ "rank2", # any rank
+ ]
+ | str = "diagonal",
+ reg: float = 1e-6,
+ hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+ h: float = 1e-3,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
+ super().__init__(defaults)
+
+ if inner is not None:
+ self.set_child('inner', inner)
+
+ @torch.no_grad
+ def step(self, vars):
+ params = TensorList(vars.params)
+ closure = vars.closure
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+ settings = self.settings[params[0]]
+ reg = settings['reg']
+ hvp_method = settings['hvp_method']
+ structure = settings['structure']
+ h = settings['h']
+
+ # ------------------------ calculate grad and hessian ------------------------ #
+ if hvp_method == 'autograd':
+ grad = vars.get_grad(create_graph=True)
+ def Hvp_fn1(x):
+ return hvp(params, grad, x, retain_graph=True)
+ Hvp_fn = Hvp_fn1
+
+ elif hvp_method == 'forward':
+ grad = vars.get_grad()
+ def Hvp_fn2(x):
+ return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
+ Hvp_fn = Hvp_fn2
+
+ elif hvp_method == 'central':
+ grad = vars.get_grad()
+ def Hvp_fn3(x):
+ return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
+ Hvp_fn = Hvp_fn3
+
+ else: raise ValueError(hvp_method)
+
+ # -------------------------------- inner step -------------------------------- #
+ update = vars.get_update()
+ if 'inner' in self.children:
+ update = apply(self.children['inner'], update, params=params, grads=grad, vars=vars)
+
+ # hessian
+ if structure.startswith('diagonal'):
+ H = Hvp_fn([torch.ones_like(p) for p in params])
+ if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
+ if structure == 'diagonal_abs': torch._foreach_abs_(H)
+ torch._foreach_add_(H, reg)
+ torch._foreach_div_(update, H)
+ vars.update = update
+ return vars
+
+ # hessian
+ raise NotImplementedError(structure)
+
+
+
+
+
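For the "diagonal" structures, the new StructuredNewton module estimates the Hessian diagonal from a single Hessian-vector product with a vector of ones, i.e. it uses the row sums H·1 as the diagonal; the estimate is exact only when the Hessian is (near-)diagonal. A small standalone check of that identity on a quadratic:

import torch

p = torch.randn(4, requires_grad=True)
A = torch.diag(torch.tensor([1.0, 2.0, 3.0, 4.0]))       # diagonal Hessian, so the estimate is exact
loss = 0.5 * p @ A @ p
(g,) = torch.autograd.grad(loss, p, create_graph=True)
(Hv,) = torch.autograd.grad(g, p, grad_outputs=torch.ones_like(p))  # H @ 1 = row sums

print(Hv)                         # 1, 2, 3, 4 == diag(A)
step = g.detach() / (Hv + 1e-6)   # the diagonal Newton update with Tikhonov regularization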
@@ -38,16 +38,19 @@ def apply_subspace_preconditioner(
  return basis @ update_projected # d

  class RandomSubspacePreconditioning(Transform):
- """full matrix rmsprop in random subspace"""
- def __init__(self, k: int, beta: float | None = 0.99):
- defaults = dict(k=k, beta=beta)
+ """full matrix rmsprop in random slowly changing subspace"""
+ def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
+ defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
  super().__init__(defaults, uses_grad=False)

+ if inner is not None: self.set_child('inner', inner)
+
  def transform(self, tensors, params, grads, vars):
  settings = self.settings[params[0]]
  g = torch.cat([t.view(-1) for t in tensors])
  k = settings['k']
  beta = settings['beta']
+ basis_beta = settings['basis_beta']

  if 'basis' not in self.global_state:
  self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
@@ -56,13 +59,19 @@ class RandomSubspacePreconditioning(Transform):
  basis = self.global_state['basis']
  accumulator = self.global_state['accumulator']

+ if basis_beta is not None:
+ basis.lerp_(torch.randn_like(basis), 1-basis_beta)
+
  update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+ if 'inner' in self.children:
+ tensors = apply(self.children['inner'], tensors, params, grads, vars)
+ g = torch.cat([t.view(-1) for t in tensors])
+
  try:
  preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
  except torch.linalg.LinAlgError:
- denom = g.abs().sum()
- if denom <= 1e-10: denom = torch.ones_like(denom)
- preconditioned = g / g.abs().sum()
+ preconditioned = g.clip(-0.1, 0.1)
  vec_to_tensors_(preconditioned, tensors)

  return tensors
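RandomSubspacePreconditioning keeps an n-by-k random basis (now slowly re-randomized through basis.lerp_ with basis_beta) and runs full-matrix RMSprop on the k-dimensional projection of the gradient before mapping the step back to parameter space. A standalone sketch of that project / precondition / project-back cycle; the exact solve inside apply_subspace_preconditioner may differ, here the inverse square root of the accumulator is used:

import torch

n, k, beta = 1000, 16, 0.99
basis = torch.randn(n, k)        # random subspace basis
acc = torch.zeros(k, k)          # full-matrix second-moment accumulator in the subspace

def precondition(g: torch.Tensor) -> torch.Tensor:
    global acc
    gk = basis.T @ g                                             # project the gradient: shape (k,)
    acc = beta * acc + (1 - beta) * torch.outer(gk, gk)          # EMA of outer products
    evals, evecs = torch.linalg.eigh(acc + 1e-8 * torch.eye(k))  # symmetric eigendecomposition
    step_k = evecs @ (evecs.T @ gk / evals.clamp(min=1e-12).sqrt())  # acc^(-1/2) @ gk
    return basis @ step_k                                        # map the step back to R^n

update = precondition(torch.randn(n))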
@@ -119,9 +128,7 @@ class HistorySubspacePreconditioning(Transform):
  try:
  preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
  except torch.linalg.LinAlgError:
- denom = g.abs().sum()
- if denom <= 1e-10: denom = torch.ones_like(denom)
- preconditioned = g / g.abs().sum()
+ preconditioned = g.clip(-0.1,0.1)
  vec_to_tensors_(preconditioned, tensors)

  return tensors
@@ -222,8 +222,7 @@ class SOAP(Transform):
  state['Q'] = get_orthogonal_matrix(state['GG'])

  state['step'] = 0
- updates.append(tensors[i].sign().div_(10))
- # updates.append(tensors[i] / tensors[i].abs().sum())
+ updates.append(tensors[i].clip(-0.1, 0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
  # I use scaled update instead as to not mess up with next modules.

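Several hunks in this release replace the sign() fallback for the very first SOAP-style step with clip(-0.1, 0.1). A quick comparison outside the library: sign() sends every entry to ±1 regardless of magnitude, while clipping passes small entries through and only caps large ones, so the first update stays on a scale closer to the raw gradient and is less likely to throw off downstream modules.

import torch

g = torch.tensor([1e-4, -0.05, 3.0])
print(g.sign())            # 1, -1, 1: magnitude information is gone
print(g.clip(-0.1, 0.1))   # 0.0001, -0.05, 0.1: small entries preserved, large ones capped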
@@ -1,4 +1,5 @@
  import math
+ from functools import partial
  from abc import ABC, abstractmethod
  from collections.abc import Iterable
  from typing import Any, Literal
@@ -33,6 +34,25 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",

  return projected_closure

+ def _projected_get_grad_override(
+ retain_graph: bool | None = None,
+ create_graph: bool = False,
+ projection: Any = ...,
+ unprojected_vars: Any = ...,
+ self: Any = ...,
+ ):
+ assert isinstance(projection, Projection)
+ assert isinstance(unprojected_vars, Vars)
+ assert isinstance(self, Vars)
+
+ if self.grad is not None: return self.grad
+ grads = unprojected_vars.get_grad(retain_graph, create_graph)
+ projected_grads = list(projection.project(grads, self, current='grads'))
+ self.grad = projected_grads
+ for p, g in zip(self.params, projected_grads):
+ p.grad = g
+ return self.grad
+

  class Projection(Module, ABC):
  """
@@ -137,6 +157,12 @@ class Projection(Module, ABC):

  # step
  projected_vars.params = self._projected_params
+ projected_vars.get_grad = partial(
+ _projected_get_grad_override,
+ projection=self,
+ unprojected_vars=vars,
+ self=projected_vars,
+ )
  projected_vars = self.children['modules'].step(projected_vars)

  # empty fake params storage
@@ -149,7 +175,7 @@ class Projection(Module, ABC):
  unprojected_vars = projected_vars.clone(clone_update=False)
  unprojected_vars.closure = vars.closure
  unprojected_vars.params = vars.params
- if unprojected_vars.grad is None: unprojected_vars.grad = vars.grad
+ unprojected_vars.grad = vars.grad

  if self._project_update:
  assert projected_vars.update is not None
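The projection hunks above lazily override get_grad on the projected Vars through functools.partial, binding the projection and the unprojected Vars as keyword arguments so gradients are fetched and projected only on first access. A minimal sketch of that override pattern with hypothetical stand-in classes (not the torchzero Vars/Projection API):

from functools import partial

class State:
    def __init__(self, grad=None):
        self.grad = grad
    def get_grad(self):
        return self.grad

def projected_get_grad(*, projection, unprojected, self):
    if self.grad is None:                                   # compute and cache on first access
        self.grad = [projection(g) for g in unprojected.get_grad()]
    return self.grad

unprojected = State(grad=[1.0, -2.0, 3.0])
projected = State()
projected.get_grad = partial(projected_get_grad, projection=abs,
                             unprojected=unprojected, self=projected)
print(projected.get_grad())   # [1.0, 2.0, 3.0], pulled lazily from the unprojected state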