torchzero 0.3.8__tar.gz → 0.3.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. {torchzero-0.3.8 → torchzero-0.3.9}/PKG-INFO +1 -1
  2. {torchzero-0.3.8 → torchzero-0.3.9}/pyproject.toml +1 -1
  3. {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_opts.py +1 -1
  4. {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_tensorlist.py +1 -1
  5. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/preconditioner.py +10 -10
  6. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/adasoap.py +1 -1
  7. torchzero-0.3.9/torchzero/modules/experimental/structured_newton.py +111 -0
  8. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/cg.py +9 -9
  9. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lbfgs.py +3 -3
  10. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lsr1.py +7 -6
  11. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/quasi_newton.py +3 -1
  12. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/newton.py +11 -6
  13. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/newton_cg.py +2 -2
  14. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/nystrom.py +6 -6
  15. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/PKG-INFO +1 -1
  16. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/SOURCES.txt +1 -0
  17. {torchzero-0.3.8 → torchzero-0.3.9}/LICENSE +0 -0
  18. {torchzero-0.3.8 → torchzero-0.3.9}/README.md +0 -0
  19. {torchzero-0.3.8 → torchzero-0.3.9}/docs/source/conf.py +0 -0
  20. {torchzero-0.3.8 → torchzero-0.3.9}/setup.cfg +0 -0
  21. {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_identical.py +0 -0
  22. {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_module.py +0 -0
  23. {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_utils_optimizer.py +0 -0
  24. {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_vars.py +0 -0
  25. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/__init__.py +0 -0
  26. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/__init__.py +0 -0
  27. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/module.py +0 -0
  28. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/transform.py +0 -0
  29. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/__init__.py +0 -0
  30. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/clipping/__init__.py +0 -0
  31. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/clipping/clipping.py +0 -0
  32. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/clipping/ema_clipping.py +0 -0
  33. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/clipping/growth_clipping.py +0 -0
  34. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/__init__.py +0 -0
  35. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/absoap.py +0 -0
  36. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/adadam.py +0 -0
  37. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/adamY.py +0 -0
  38. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/algebraic_newton.py +0 -0
  39. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/curveball.py +0 -0
  40. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/gradmin.py +0 -0
  41. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/newton_solver.py +0 -0
  42. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
  43. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/soapy.py +0 -0
  44. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/spectral.py +0 -0
  45. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/subspace_preconditioners.py +0 -0
  46. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/tropical_newton.py +0 -0
  47. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/functional.py +0 -0
  48. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/__init__.py +0 -0
  49. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/fdm.py +0 -0
  50. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
  51. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
  52. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/rfdm.py +0 -0
  53. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/__init__.py +0 -0
  54. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/backtracking.py +0 -0
  55. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/line_search.py +0 -0
  56. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/scipy.py +0 -0
  57. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/strong_wolfe.py +0 -0
  58. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/trust_region.py +0 -0
  59. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/lr/__init__.py +0 -0
  60. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/lr/lr.py +0 -0
  61. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/lr/step_size.py +0 -0
  62. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/__init__.py +0 -0
  63. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/averaging.py +0 -0
  64. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/cautious.py +0 -0
  65. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/ema.py +0 -0
  66. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/experimental.py +0 -0
  67. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/matrix_momentum.py +0 -0
  68. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/momentum.py +0 -0
  69. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/__init__.py +0 -0
  70. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/accumulate.py +0 -0
  71. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/binary.py +0 -0
  72. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/debug.py +0 -0
  73. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/misc.py +0 -0
  74. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/multi.py +0 -0
  75. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/reduce.py +0 -0
  76. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/split.py +0 -0
  77. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/switch.py +0 -0
  78. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/unary.py +0 -0
  79. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/utility.py +0 -0
  80. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/__init__.py +0 -0
  81. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/adagrad.py +0 -0
  82. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/adam.py +0 -0
  83. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/lion.py +0 -0
  84. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/muon.py +0 -0
  85. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/orthograd.py +0 -0
  86. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/rmsprop.py +0 -0
  87. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/rprop.py +0 -0
  88. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/shampoo.py +0 -0
  89. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/soap.py +0 -0
  90. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/sophia_h.py +0 -0
  91. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/__init__.py +0 -0
  92. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/dct.py +0 -0
  93. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/fft.py +0 -0
  94. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/galore.py +0 -0
  95. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/projection.py +0 -0
  96. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/structural.py +0 -0
  97. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/__init__.py +0 -0
  98. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
  99. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -0
  100. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/olbfgs.py +0 -0
  101. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/__init__.py +0 -0
  102. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/smoothing/__init__.py +0 -0
  103. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/smoothing/gaussian.py +0 -0
  104. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/smoothing/laplacian.py +0 -0
  105. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/weight_decay/__init__.py +0 -0
  106. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/weight_decay/weight_decay.py +0 -0
  107. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/wrappers/__init__.py +0 -0
  108. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
  109. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/__init__.py +0 -0
  110. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/utility/__init__.py +0 -0
  111. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/utility/split.py +0 -0
  112. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/wrappers/__init__.py +0 -0
  113. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/wrappers/nevergrad.py +0 -0
  114. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/wrappers/nlopt.py +0 -0
  115. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/wrappers/scipy.py +0 -0
  116. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/__init__.py +0 -0
  117. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/compile.py +0 -0
  118. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/derivatives.py +0 -0
  119. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/__init__.py +0 -0
  120. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/benchmark.py +0 -0
  121. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/matrix_funcs.py +0 -0
  122. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/orthogonalize.py +0 -0
  123. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/qr.py +0 -0
  124. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/solve.py +0 -0
  125. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/svd.py +0 -0
  126. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/numberlist.py +0 -0
  127. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/ops.py +0 -0
  128. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/optimizer.py +0 -0
  129. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/optuna_tools.py +0 -0
  130. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/params.py +0 -0
  131. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/python_tools.py +0 -0
  132. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/tensorlist.py +0 -0
  133. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/torch_tools.py +0 -0
  134. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/dependency_links.txt +0 -0
  135. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/requires.txt +0 -0
  136. {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.3.8 → torchzero-0.3.9}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.8
+ Version: 0.3.9
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  License: MIT License
{torchzero-0.3.8 → torchzero-0.3.9}/pyproject.toml
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
  name = "torchzero"
  description = "Modular optimization library for PyTorch."

- version = "0.3.8"
+ version = "0.3.9"
  dependencies = [
      "torch",
      "numpy",
{torchzero-0.3.8 → torchzero-0.3.9}/tests/test_opts.py
@@ -745,7 +745,7 @@ SSVM = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
      sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
      needs_closure=True,
-     func='rosen', steps=50, loss=0.02, merge_invariant=True,
+     func='rosen', steps=50, loss=1e-10, merge_invariant=True,
      sphere_steps=10, sphere_loss=0,
  )

{torchzero-0.3.8 → torchzero-0.3.9}/tests/test_tensorlist.py
@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho
      expected = vec_equiv_func()

      if isinstance(result, bool): assert result == expected
-     else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
+     else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"


  def test_global_vector_norm(simple_tl: TensorList):
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/preconditioner.py
@@ -45,12 +45,11 @@ class Preconditioner(Transform):
  update_freq = global_settings['__update_freq']

  scale_first = global_settings['__scale_first']
- scale_factor = 0
+ scale_factor = 1
  if scale_first and step == 0:
-     # initial step size guess from pytorch LBFGS was too unstable
-     # I switched to norm
-     tensors = TensorList(tensors)
-     scale_factor = tensors.abs().global_mean().clip(min=1)
+     # initial step size guess from pytorch LBFGS
+     scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
+     scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)

  # update preconditioner
  if step % update_freq == 0:
@@ -65,7 +64,7 @@ class Preconditioner(Transform):

  # scale initial step, when preconditioner might not have been applied
  if scale_first and step == 0:
-     torch._foreach_div_(tensors, scale_factor)
+     torch._foreach_mul_(tensors, scale_factor)

  self.global_state['__step'] = step + 1
  return tensors
@@ -82,10 +81,11 @@ class Preconditioner(Transform):
  update_freq = global_settings['__update_freq']

  scale_first = global_settings['__scale_first']
- scale_factor = 0
+ scale_factor = 1
  if scale_first and step == 0:
-     # initial step size guess from pytorch LBFGS was too unstable
-     scale_factor = tensors_vec.abs().mean().clip(min=1)
+     # initial step size guess from pytorch LBFGS
+     scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
+     scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)

  # update preconditioner
  if step % update_freq == 0:
@@ -101,7 +101,7 @@ class Preconditioner(Transform):

  # scale initial step, when preconditioner might not have been applied
  if scale_first and step == 0:
-     tensors_vec /= scale_factor
+     tensors_vec *= scale_factor

  tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
  self.global_state['__step'] = step + 1
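Both hunks above replace the mean-based first-step guess (applied by division) with a sum-based factor (applied by multiplication), and the identical rule reappears in lbfgs.py, lsr1.py and quasi_newton.py below. A minimal standalone sketch of the new rule on a plain tensor, rather than the library's TensorList:

```python
# Sketch of the 0.3.9 first-step scaling rule (plain tensor stand-in for the
# module's TensorList/_foreach code path).
import torch

def first_step_scale(g: torch.Tensor) -> torch.Tensor:
    # 1 / sum(|g|), with the sum clipped to >= 1 so the factor never exceeds 1
    scale = 1 / g.abs().sum().clip(min=1)
    # floored at dtype eps so an enormous gradient cannot zero the step outright
    return scale.clip(min=torch.finfo(g.dtype).eps)

print(first_step_scale(torch.tensor([3.0, -4.0, 5.0])))  # tensor(0.0833), shrink by 1/12
print(first_step_scale(torch.tensor([1e-3, -2e-3])))     # tensor(1.), small updates pass through unscaled
```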
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/adasoap.py
@@ -220,7 +220,7 @@ class AdaSOAP(Transform):
  state['step'] = 0
  updates.append(tensors[i].clip(-0.1,0.1))
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
- # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
+ # that can mess with other modules scaling

  # Projecting gradients to the eigenbases of Shampoo's preconditioner
  # i.e. projecting to the eigenbases of matrices in state['GG']
torchzero-0.3.9/torchzero/modules/experimental/structured_newton.py (new file)
@@ -0,0 +1,111 @@
+ # idea https://arxiv.org/pdf/2212.09841
+ import warnings
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, apply
+ from ...utils import TensorList, vec_to_tensors
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     hessian_mat,
+     hvp,
+     hvp_fd_central,
+     hvp_fd_forward,
+     jacobian_and_hessian_wrt,
+ )
+
+
+ class StructuredNewton(Module):
+     """TODO
+     Args:
+         structure (str, optional): structure.
+         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
+         hvp_method (str):
+             how to calculate hvp_method. Defaults to "autograd".
+         inner (Chainable | None, optional): inner modules. Defaults to None.
+
+     """
+     def __init__(
+         self,
+         structure: Literal[
+             "diagonal",
+             "diagonal1",
+             "diagonal_abs",
+             "tridiagonal",
+             "circulant",
+             "toeplitz",
+             "toeplitz_like",
+             "hankel",
+             "rank1",
+             "rank2", # any rank
+         ]
+         | str = "diagonal",
+         reg: float = 1e-6,
+         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+         h: float = 1e-3,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = TensorList(vars.params)
+         closure = vars.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         reg = settings['reg']
+         hvp_method = settings['hvp_method']
+         structure = settings['structure']
+         h = settings['h']
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         if hvp_method == 'autograd':
+             grad = vars.get_grad(create_graph=True)
+             def Hvp_fn1(x):
+                 return hvp(params, grad, x, retain_graph=True)
+             Hvp_fn = Hvp_fn1
+
+         elif hvp_method == 'forward':
+             grad = vars.get_grad()
+             def Hvp_fn2(x):
+                 return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
+             Hvp_fn = Hvp_fn2
+
+         elif hvp_method == 'central':
+             grad = vars.get_grad()
+             def Hvp_fn3(x):
+                 return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
+             Hvp_fn = Hvp_fn3
+
+         else: raise ValueError(hvp_method)
+
+         # -------------------------------- inner step -------------------------------- #
+         update = vars.get_update()
+         if 'inner' in self.children:
+             update = apply(self.children['inner'], update, params=params, grads=grad, vars=vars)
+
+         # hessian
+         if structure.startswith('diagonal'):
+             H = Hvp_fn([torch.ones_like(p) for p in params])
+             if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
+             if structure == 'diagonal_abs': torch._foreach_abs_(H)
+             torch._foreach_add_(H, reg)
+             torch._foreach_div_(update, H)
+             vars.update = update
+             return vars
+
+         # hessian
+         raise NotImplementedError(structure)
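Of the structures listed in the signature, only the diagonal family is implemented so far: a single Hessian-vector product with the all-ones vector yields the Hessian's row sums, which coincide with diag(H) exactly when the Hessian really is diagonal, and the update is divided elementwise by that estimate plus the Tikhonov term. A minimal sketch of the same idea outside the library, where `diag_newton_step` is a hypothetical helper mirroring the "diagonal_abs" variant:

```python
# A single HVP with the ones vector yields H @ 1, the row sums of H; under a
# diagonal-structure assumption this is diag(H). `diag_newton_step` is a
# hypothetical standalone helper, not the library's API.
import torch

def diag_newton_step(f, x: torch.Tensor, reg: float = 1e-6) -> torch.Tensor:
    g = torch.autograd.grad(f(x), x, create_graph=True)[0]
    Hd = torch.autograd.grad(g, x, grad_outputs=torch.ones_like(x))[0]  # H @ 1
    return g / (Hd.abs() + reg)  # "diagonal_abs"-style damped Newton step

x = torch.tensor([1.5, -0.7], requires_grad=True)
f = lambda t: (t ** 4).sum()   # separable objective, so H is exactly diagonal
print(diag_newton_step(f, x))  # roughly x / 3, the true Newton step for t**4
```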
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/cg.py
@@ -64,7 +64,7 @@ class ConguateGradientBase(Transform, ABC):
  # ------------------------------- Polak-Ribière ------------------------------ #
  def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
      denom = prev_g.dot(prev_g)
-     if denom == 0: return 0
+     if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
      return g.dot(g - prev_g) / denom

  class PolakRibiere(ConguateGradientBase):
@@ -76,8 +76,8 @@ class PolakRibiere(ConguateGradientBase):
      return polak_ribiere_beta(g, prev_g)

  # ------------------------------ Fletcher–Reeves ----------------------------- #
- def fletcher_reeves_beta(gg, prev_gg):
-     if prev_gg == 0: return 0
+ def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
+     if prev_gg.abs() <= torch.finfo(gg.dtype).eps: return 0
      return gg / prev_gg

  class FletcherReeves(ConguateGradientBase):
@@ -98,7 +98,7 @@ class FletcherReeves(ConguateGradientBase):
  def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
      grad_diff = g - prev_g
      denom = prev_d.dot(grad_diff)
-     if denom == 0: return 0
+     if denom.abs() < torch.finfo(g[0].dtype).eps: return 0
      return (g.dot(grad_diff) / denom).neg()


@@ -114,7 +114,7 @@ class HestenesStiefel(ConguateGradientBase):
  # --------------------------------- Dai–Yuan --------------------------------- #
  def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
      denom = prev_d.dot(g - prev_g)
-     if denom == 0: return 0
+     if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
      return (g.dot(g) / denom).neg()

  class DaiYuan(ConguateGradientBase):
@@ -129,7 +129,7 @@ class DaiYuan(ConguateGradientBase):
  # -------------------------------- Liu-Storey -------------------------------- #
  def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
      denom = prev_g.dot(prev_d)
-     if denom == 0: return 0
+     if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
      return g.dot(g - prev_g) / denom

  class LiuStorey(ConguateGradientBase):
@@ -159,7 +159,7 @@ class ConjugateDescent(Transform):
      self.global_state['denom'] = torch.tensor(0.).to(g[0])

  prev_gd = self.global_state.get('prev_gd', 0)
- if prev_gd == 0: beta = 0
+ if abs(prev_gd) <= torch.finfo(g[0].dtype).eps: beta = 0
  else: beta = g.dot(g) / prev_gd

  # inner step
@@ -176,7 +176,7 @@ class ConjugateDescent(Transform):
  def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
      g_diff = g - prev_g
      denom = prev_d.dot(g_diff)
-     if denom == 0: return 0
+     if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0

      term1 = 1/denom
      # term2
@@ -198,7 +198,7 @@ class HagerZhang(ConguateGradientBase):
  def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
      grad_diff = g - prev_g
      denom = prev_d.dot(grad_diff)
-     if denom == 0: return 0
+     if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0

      # Dai-Yuan
      dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
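Every beta formula in this file gets the same guard: the denominator is compared against the dtype's machine epsilon instead of exact zero. An exact `denom == 0` check almost never fires on floating-point tensors, while a merely tiny denominator makes beta blow up; thresholding at eps restarts the conjugate direction (beta = 0) instead. The Polak-Ribière case, sketched on plain tensors (the library versions operate on TensorList, but the guard is identical):

```python
import torch

def polak_ribiere_beta(g: torch.Tensor, prev_g: torch.Tensor) -> torch.Tensor:
    denom = prev_g.dot(prev_g)
    # near-zero denominator: restart the direction rather than divide
    if denom.abs() <= torch.finfo(g.dtype).eps:
        return torch.zeros((), dtype=g.dtype)
    return g.dot(g - prev_g) / denom
```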
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lbfgs.py
@@ -38,9 +38,9 @@ def lbfgs(
  if len(s_history) == 0 or y_k is None or ys_k is None:

      # initial step size guess modified from pytorch L-BFGS
-     scale = 1 / tensors_.abs().global_sum()
-     if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-     return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
+     scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+     scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+     return tensors_.mul_(scale_factor)

  else:
      # 1st loop
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lsr1.py
@@ -17,9 +17,9 @@ def lsr1_(
  ):
      if step == 0 or not s_history:
          # initial step size guess from pytorch
-         scale = 1 / tensors_.abs().global_sum()
-         if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-         return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
+         scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+         scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+         return tensors_.mul_(scale_factor)

      m = len(s_history)

@@ -65,9 +65,10 @@ def lsr1_(
      Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]

      if scale_second and step == 1:
-         scale = 1 / tensors_.abs().global_sum()
-         if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-         Hx.mul_(min(1.0, scale)) # pyright:ignore[reportArgumentType]
+         scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+         scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+         Hx.mul_(scale_factor)
+

      return Hx

{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/quasi_newton.py
@@ -122,7 +122,9 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
  step = state.get('step', 0)

  if settings['scale_second'] and step == 2:
-     tensor = tensor / tensor.abs().mean().clip(min=1)
+     scale_factor = 1 / tensor.abs().sum().clip(min=1)
+     scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
+     tensor = tensor * scale_factor

  inverse = settings['inverse']
  if inverse:
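The practical difference between the two formulas is easiest to see numerically. The 0.3.8 rule in the lbfgs/lsr1 hunks above fell back to 1/mean whenever 1/sum dropped below 1e-5, and the min(1, ...) cap then frequently returned exactly 1, i.e. no scaling at all for large dense gradients; the 0.3.9 rule keeps the sum-based shrink. A worked comparison on a plain tensor (the library applies this over TensorList, but the arithmetic is the same):

```python
import torch

def scale_038(g: torch.Tensor):  # old rule from the lbfgs/lsr1 hunks above
    scale = 1 / g.abs().sum()
    if scale < 1e-5: scale = 1 / g.abs().mean()  # mean fallback
    return min(1.0, scale)

def scale_039(g: torch.Tensor):  # new shared rule
    scale = 1 / g.abs().sum().clip(min=1)
    return scale.clip(min=torch.finfo(g.dtype).eps)

g = torch.full((1_000_000,), 0.5)  # sum |g| = 5e5, mean |g| = 0.5
print(scale_038(g))  # 1.0: fallback fires (2e-6 < 1e-5), then min(1, 1/0.5) caps at 1
print(scale_039(g))  # tensor(2.0000e-06): the first step is genuinely shrunk
```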
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/newton.py
@@ -1,14 +1,18 @@
  import warnings
+ from collections.abc import Callable
  from functools import partial
  from typing import Literal
- from collections.abc import Callable
+
  import torch

- from ...core import Chainable, apply, Module
- from ...utils import vec_to_tensors, TensorList
+ from ...core import Chainable, Module, apply
+ from ...utils import TensorList, vec_to_tensors
  from ...utils.derivatives import (
      hessian_list_to_mat,
      hessian_mat,
+     hvp,
+     hvp_fd_central,
+     hvp_fd_forward,
      jacobian_and_hessian_wrt,
  )

@@ -117,9 +121,10 @@ class Newton(Module):
  raise ValueError(hessian_method)

  # -------------------------------- inner step -------------------------------- #
+ update = vars.get_update()
  if 'inner' in self.children:
-     g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
-     g = torch.cat([t.view(-1) for t in g_list])
+     update = apply(self.children['inner'], update, params=params, grads=list(g_list), vars=vars)
+ g = torch.cat([t.view(-1) for t in update])

  # ------------------------------- regulazition ------------------------------- #
  if eig_reg: H = eig_tikhonov_(H, reg)
@@ -139,4 +144,4 @@ class Newton(Module):
  if update is None: update = least_squares_solve(H, g)

  vars.update = vec_to_tensors(update, params)
- return vars
+ return vars
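Besides the import reordering, the substantive change is that the solve now targets `vars.get_update()`, the output of whatever modules ran earlier, optionally passed through the `inner` child, instead of the raw gradient; NewtonCG and the Nyström modules below get the same treatment. The tail of the method keeps its regularize-then-fall-back chain. A standalone sketch of that solve pattern with plain torch.linalg calls (the library's eig_tikhonov_ and least_squares_solve helpers are not reproduced here):

```python
import torch

def newton_direction(H: torch.Tensor, b: torch.Tensor, reg: float = 1e-6) -> torch.Tensor:
    # Tikhonov regularization, then a direct solve with a least-squares fallback
    H = H + reg * torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    try:
        return torch.linalg.solve(H, b)
    except torch.linalg.LinAlgError:
        # singular even after regularization: take a least-squares solution
        return torch.linalg.lstsq(H, b.unsqueeze(-1)).solution.squeeze(-1)
```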
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/newton_cg.py
@@ -66,9 +66,9 @@ class NewtonCG(Module):


  # -------------------------------- inner step -------------------------------- #
- b = grad
+ b = vars.get_update()
  if 'inner' in self.children:
-     b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+     b = as_tensorlist(apply(self.children['inner'], b, params=params, grads=grad, vars=vars))

  # ---------------------------------- run cg ---------------------------------- #
  x0 = None
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/nystrom.py
@@ -15,7 +15,7 @@ class NystromSketchAndSolve(Module):
      rank: int,
      reg: float = 1e-3,
      hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-     h=1e-3,
+     h=1e-2,
      inner: Chainable | None = None,
      seed: int | None = None,
  ):
@@ -74,9 +74,9 @@ class NystromSketchAndSolve(Module):


  # -------------------------------- inner step -------------------------------- #
- b = grad
+ b = vars.get_update()
  if 'inner' in self.children:
-     b = apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars)
+     b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)

  # ------------------------------ sketch&n&solve ------------------------------ #
  x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
@@ -93,7 +93,7 @@ class NystromPCG(Module):
      tol=1e-3,
      reg: float = 1e-6,
      hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-     h=1e-3,
+     h=1e-2,
      inner: Chainable | None = None,
      seed: int | None = None,
  ):
@@ -156,9 +156,9 @@ class NystromPCG(Module):


  # -------------------------------- inner step -------------------------------- #
- b = grad
+ b = vars.get_update()
  if 'inner' in self.children:
-     b = apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars)
+     b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)

  # ------------------------------ sketch&n&solve ------------------------------ #
  x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
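Two changes here: the right-hand side b is now `vars.get_update()` as in Newton/NewtonCG above, and the default finite-difference step h for the "forward" and "central" HVP modes grows from 1e-3 to 1e-2, trading a little truncation error for less round-off noise in the sketched Hessian. A standalone sketch of the two difference schemes (free functions with assumed signatures; the library's hvp_fd_forward/hvp_fd_central also normalize the direction, but their exact interfaces differ):

```python
import torch

def hvp_fd_forward(grad_fn, x, v, h=1e-2):
    # forward difference: one extra gradient evaluation, O(h) truncation error
    vnorm = v.norm()
    vn = v / vnorm
    return (grad_fn(x + h * vn) - grad_fn(x)) / h * vnorm

def hvp_fd_central(grad_fn, x, v, h=1e-2):
    # central difference: two extra gradient evaluations, O(h^2) truncation error
    vnorm = v.norm()
    vn = v / vnorm
    return (grad_fn(x + h * vn) - grad_fn(x - h * vn)) / (2 * h) * vnorm
```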
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.8
+ Version: 0.3.9
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  License: MIT License
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/SOURCES.txt
@@ -36,6 +36,7 @@ torchzero/modules/experimental/newton_solver.py
  torchzero/modules/experimental/reduce_outward_lr.py
  torchzero/modules/experimental/soapy.py
  torchzero/modules/experimental/spectral.py
+ torchzero/modules/experimental/structured_newton.py
  torchzero/modules/experimental/subspace_preconditioners.py
  torchzero/modules/experimental/tropical_newton.py
  torchzero/modules/grad_approximation/__init__.py