torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
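Note on the layout changes recorded above: the linear-algebra helpers moved from torchzero/utils/linalg to a new top-level torchzero/linalg package (entries 16-25; the old torchzero/utils/linalg files are removed in entries 161-164), and the single-file SciPy wrapper was split into a torchzero/optim/wrappers/scipy/ package (entries 138-146; the old torchzero/optim/wrappers/scipy.py is removed in entry 160). A minimal sketch of how an import changes, assuming the public import paths simply mirror the new file layout shown above:

    # torchzero 0.3.14 (old layout)
    # from torchzero.utils.linalg.linear_operator import Dense, DenseWithInverse

    # torchzero 0.4.0 (new layout)
    from torchzero.linalg.linear_operator import Dense, DenseWithInverse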
@@ -1,19 +1,17 @@
-from collections.abc import Callable
-from contextlib import nullcontext
 from abc import ABC, abstractmethod
+from collections.abc import Callable, Mapping
+from typing import Any
+
 import numpy as np
 import torch
 
-from ...core import Chainable, Module, apply_transform, Var
-from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    flatten_jacobian,
-    jacobian_wrt,
-)
+from ...core import Chainable, DerivativesMethod, Objective, Transform
+from ...utils import TensorList, vec_to_tensors
+
 
-class HigherOrderMethodBase(Module, ABC):
-    def __init__(self, defaults: dict | None = None, vectorize: bool = True):
-        self._vectorize = vectorize
+class HigherOrderMethodBase(Transform, ABC):
+    def __init__(self, defaults: dict | None = None, derivatives_method: DerivativesMethod = 'batched_autograd'):
+        self._derivatives_method: DerivativesMethod = derivatives_method
         super().__init__(defaults)
 
     @abstractmethod
@@ -21,61 +19,27 @@ class HigherOrderMethodBase(Module, ABC):
         self,
         x: torch.Tensor,
         evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
-        var: Var,
+        objective: Objective,
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """"""
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        x0 = params.clone()
-        closure = var.closure
+    def apply_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+
+        closure = objective.closure
         if closure is None: raise RuntimeError('MultipointNewton requires closure')
-        vectorize = self._vectorize
+        derivatives_method = self._derivatives_method
 
         def evaluate(x, order) -> tuple[torch.Tensor, ...]:
             """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
-            params.from_vec_(x)
-
-            if order == 0:
-                loss = closure(False)
-                params.copy_(x0)
-                return (loss, )
-
-            if order == 1:
-                with torch.enable_grad():
-                    loss = closure()
-                grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                params.copy_(x0)
-                return loss, torch.cat([g.ravel() for g in grad])
-
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-
-                g_list = torch.autograd.grad(loss, params, create_graph=True)
-                var.grad = list(g_list)
-
-                g = torch.cat([t.ravel() for t in g_list])
-                n = g.numel()
-                ret = [loss, g]
-                T = g # current derivatives tensor
-
-                # get all derivative up to order
-                for o in range(2, order + 1):
-                    is_last = o == order
-                    T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-                    with torch.no_grad() if is_last else nullcontext():
-                        # the shape is (ndim, ) * order
-                        T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                    ret.append(T)
-
-            params.copy_(x0)
-            return tuple(ret)
+            return objective.derivatives_at(x, order, method=derivatives_method)
 
         x = torch.cat([p.ravel() for p in params])
-        dir = self.one_iteration(x, evaluate, var)
-        var.update = vec_to_tensors(dir, var.params)
-        return var
+        dir = self.one_iteration(x, evaluate, objective, settings[0])
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
 
 def _inv(A: torch.Tensor, lstsq:bool) -> torch.Tensor:
     if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable
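The hunk above replaces the hand-rolled higher-order autograd loop with a single call to objective.derivatives_at, and one_iteration now receives the objective plus a per-parameter settings mapping. A minimal sketch of what a subclass looks like under the new base class, following the pattern of the concrete classes further down; the NewtonLike name is illustrative only and not part of the package:

    class NewtonLike(HigherOrderMethodBase):
        """Hypothetical subclass: a plain Newton step written against the new `evaluate` callback."""
        def __init__(self, lstsq: bool = False, derivatives_method: DerivativesMethod = 'batched_autograd'):
            super().__init__(defaults=dict(lstsq=lstsq), derivatives_method=derivatives_method)

        @torch.no_grad
        def one_iteration(self, x, evaluate, objective, setting):
            # evaluate(x, 2) returns (loss, gradient, hessian) at the flat parameter vector x
            _, g, H = evaluate(x, 2)
            # the returned vector becomes objective.updates (the step subtracted from x),
            # matching the `return x - x_star` pattern of the concrete classes below
            return _inv(H, setting['lstsq']) @ g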
@@ -106,16 +70,15 @@ class SixthOrder3P(HigherOrderMethodBase):
 
     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
     """
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_3p(x, f, f_j, lstsq)
+        x_star = sixth_order_3p(x, f, f_j, setting['lstsq'])
         return x - x_star
 
 # I don't think it works (I tested root finding with this and it goes all over the place)
@@ -173,15 +136,14 @@ def sixth_order_5p(x:torch.Tensor, f_j, lstsq:bool=False):
 
 class SixthOrder5P(HigherOrderMethodBase):
     """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_5p(x, f_j, lstsq)
+        x_star = sixth_order_5p(x, f_j, setting['lstsq'])
         return x - x_star
 
 # 2f 1J 2 solves
@@ -196,16 +158,15 @@ class TwoPointNewton(HigherOrderMethodBase):
     """two-point Newton method with frozen derivative with third order convergence.
 
     Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = two_point_newton(x, f, f_j, lstsq)
+        x_star = two_point_newton(x, f, f_j, setting['lstsq'])
         return x - x_star
 
 #3f 2J 1inv
@@ -224,15 +185,14 @@ def sixth_order_3pm2(x:torch.Tensor, f, f_j, lstsq:bool=False):
 
 class SixthOrder3PM2(HigherOrderMethodBase):
     """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
         def f(x): return evaluate(x, 1)[1]
-        x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+        x_star = sixth_order_3pm2(x, f, f_j, setting['lstsq'])
         return x - x_star
 
@@ -1,21 +1,12 @@
-import warnings
 from collections.abc import Callable
-from functools import partial
 from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, vec_to_tensors
-from ...utils.derivatives import (
-    flatten_jacobian,
-    hessian_mat,
-    hvp,
-    hvp_fd_central,
-    hvp_fd_forward,
-    jacobian_and_hessian_wrt,
-)
-from ...utils.linalg.linear_operator import DenseWithInverse, Dense
+from ...core import Chainable, Transform, Objective, HessianMethod, Module
+from ...utils import vec_to_tensors
+from ...linalg.linear_operator import Dense, DenseWithInverse
+
 
 def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
@@ -26,10 +17,9 @@ def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     return None
 
 def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
-    x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
+    L, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
-        g.unsqueeze_(1)
-        return torch.cholesky_solve(g, x)
+        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
     return None
 
 def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
@@ -49,10 +39,59 @@ def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_n
     except torch.linalg.LinAlgError:
         return None
 
+def _newton_step(objective: Objective, H: torch.Tensor, damping:float, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None, no_inner: Module | None = None) -> torch.Tensor:
+    """INNER SHOULD BE NONE IN MOST CASES! Because Transform already has inner.
+    Returns the update tensor, then do vec_to_tensor(update, params)"""
+    # -------------------------------- inner step -------------------------------- #
+    if no_inner is not None:
+        objective = no_inner.step(objective)
+
+    update = objective.get_updates()
+
+    g = torch.cat([t.ravel() for t in update])
+    if g_proj is not None: g = g_proj(g)
+
+    # ----------------------------------- solve ---------------------------------- #
+    update = None
+
+    if damping != 0:
+        H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
+
+    if H_tfm is not None:
+        ret = H_tfm(H, g)
+
+        if isinstance(ret, torch.Tensor):
+            update = ret
+
+        else: # returns (H, is_inv)
+            H, is_inv = ret
+            if is_inv: update = H @ g
+
+    if eigval_fn is not None:
+        update = _eigh_solve(H, g, eigval_fn, search_negative=False)
+
+    if update is None and use_lstsq: update = _least_squares_solve(H, g)
+    if update is None: update = _cholesky_solve(H, g)
+    if update is None: update = _lu_solve(H, g)
+    if update is None: update = _least_squares_solve(H, g)
+
+    return update
+
+def _get_H(H: torch.Tensor, eigval_fn):
+    if eigval_fn is not None:
+        try:
+            L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+            L: torch.Tensor = eigval_fn(L)
+            H = Q @ L.diag_embed() @ Q.mH
+            H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
+            return DenseWithInverse(H, H_inv)
 
+        except torch.linalg.LinAlgError:
+            pass
 
+    return Dense(H)
 
-class Newton(Module):
+class Newton(Transform):
     """Exact newton's method via autograd.
     Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
 
@@ -60,7 +99,7 @@ class Newton(Module):
     ``g`` can be output of another module, if it is specifed in ``inner`` argument.
 
     Note:
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
     Note:
         This module requires the a closure passed to the optimizer step,
@@ -77,11 +116,6 @@ class Newton(Module):
             when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
             If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
            argument will be ignored.
-        hessian_method (str):
-            how to calculate hessian. Defaults to "autograd".
-        vectorize (bool, optional):
-            whether to enable vectorized hessian. Defaults to True.
-        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         H_tfm (Callable | None, optional):
             optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
 
@@ -94,6 +128,22 @@ class Newton(Module):
         eigval_fn (Callable | None, optional):
             optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
             If this is specified, eigendecomposition will be used to invert the hessian.
+        hessian_method (str):
+            Determines how hessian is computed.
+
+            - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
+            - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_fwdrev"`` but uses more memory (``"batched_autograd"`` seems to be faster)
+            - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
+            - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+            - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
+
+            Defaults to ``"batched_autograd"``.
+        h (float, optional):
+            finite difference step size for "fd_forward" and "fd_central".
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
 
     # See also
@@ -111,10 +161,9 @@ class Newton(Module):
     The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
     Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
 
-    Additionally, if ``eigval_fn`` is specified or ``search_negative`` is ``True``,
-    eigendecomposition of the hessian is computed, ``eigval_fn`` is applied to the eigenvalues,
-    and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues.
-    This is more generally more computationally expensive.
+    Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
+    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive,
+    but not by much
 
     ## Handling non-convexity
 
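In symbols, the solve described above is the damped Newton system; when ``eigval_fn`` (call it φ) is given, the system is solved through an eigendecomposition of the damped hessian. Notation added here for clarity, with λ the ``damping`` value:

    (H + \lambda I)\, d = g, \qquad x_{k+1} = x_k - d

    H + \lambda I = Q \Lambda Q^{\top} \;\Rightarrow\; d = Q\, \varphi(\Lambda)^{-1} Q^{\top} g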
@@ -167,217 +216,45 @@ class Newton(Module):
     def __init__(
         self,
         damping: float = 0,
-        search_negative: bool = False,
         use_lstsq: bool = False,
         update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
+        inner: Chainable | None = None,
     ):
-        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, search_negative=search_negative, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def update(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        damping = settings['damping']
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-        update_freq = settings['update_freq']
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        g_list = var.grad
-        H = None
-        if step % update_freq == 0:
-            # ------------------------ calculate grad and hessian ------------------------ #
-            if hessian_method == 'autograd':
-                with torch.enable_grad():
-                    loss = var.loss = var.loss_approx = closure(False)
-                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                    g_list = [t[0] for t in g_list] # remove leading dim from loss
-                    var.grad = g_list
-                    H = flatten_jacobian(H_list)
-
-            elif hessian_method in ('func', 'autograd.functional'):
-                strat = 'forward-mode' if vectorize else 'reverse-mode'
-                with torch.enable_grad():
-                    g_list = var.get_grad(retain_graph=True)
-                    H = hessian_mat(partial(closure, backward=False), params,
-                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-            else:
-                raise ValueError(hessian_method)
-
-            if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
-            self.global_state['H'] = H
+        defaults = locals().copy()
+        del defaults['self'], defaults['update_freq'], defaults["inner"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def apply(self, var):
-        H = self.global_state["H"]
-
-        params = var.params
-        settings = self.settings[params[0]]
-        search_negative = settings['search_negative']
-        H_tfm = settings['H_tfm']
-        eigval_fn = settings['eigval_fn']
-        use_lstsq = settings['use_lstsq']
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
-        # ----------------------------------- solve ---------------------------------- #
-        update = None
-        if H_tfm is not None:
-            ret = H_tfm(H, g)
+        _, _, self.global_state['H'] = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
 
-            if isinstance(ret, torch.Tensor):
-                update = ret
-
-            else: # returns (H, is_inv)
-                H, is_inv = ret
-                if is_inv: update = H @ g
-
-        if search_negative or (eigval_fn is not None):
-            update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)
-
-        if update is None and use_lstsq: update = _least_squares_solve(H, g)
-        if update is None: update = _cholesky_solve(H, g)
-        if update is None: update = _lu_solve(H, g)
-        if update is None: update = _least_squares_solve(H, g)
-
-        var.update = vec_to_tensors(update, params)
-
-        return var
-
-    def get_H(self,var):
-        H = self.global_state["H"]
-        settings = self.defaults
-        if settings['eigval_fn'] is not None:
-            try:
-                L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
-                L = settings['eigval_fn'](L)
-                H = Q @ L.diag_embed() @ Q.mH
-                H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
-                return DenseWithInverse(H, H_inv)
-
-            except torch.linalg.LinAlgError:
-                pass
-
-        return Dense(H)
-
-
-class InverseFreeNewton(Module):
-    """Inverse-free newton's method
-
-    .. note::
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating the hessian.
-        The closure must accept a ``backward`` argument (refer to documentation).
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+        fs = settings[0]
 
-    .. warning::
-        this uses roughly O(N^2) memory.
+        update = _newton_step(
+            objective=objective,
+            H = self.global_state["H"],
+            damping = fs["damping"],
+            H_tfm = fs["H_tfm"],
+            eigval_fn = fs["eigval_fn"],
+            use_lstsq = fs["use_lstsq"],
+        )
 
-    Reference
-        Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
-    """
-    def __init__(
-        self,
-        update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
-        super().__init__(defaults)
+        objective.updates = vec_to_tensors(update, params)
+        return objective
 
-        if inner is not None:
-            self.set_child('inner', inner)
+    def get_H(self,objective=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
 
-    @torch.no_grad
-    def update(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-        update_freq = settings['update_freq']
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        g_list = var.grad
-        Y = None
-        if step % update_freq == 0:
-            # ------------------------ calculate grad and hessian ------------------------ #
-            if hessian_method == 'autograd':
-                with torch.enable_grad():
-                    loss = var.loss = var.loss_approx = closure(False)
-                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                    g_list = [t[0] for t in g_list] # remove leading dim from loss
-                    var.grad = g_list
-                    H = flatten_jacobian(H_list)
-
-            elif hessian_method in ('func', 'autograd.functional'):
-                strat = 'forward-mode' if vectorize else 'reverse-mode'
-                with torch.enable_grad():
-                    g_list = var.get_grad(retain_graph=True)
-                    H = hessian_mat(partial(closure, backward=False), params,
-                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-            else:
-                raise ValueError(hessian_method)
-
-            self.global_state["H"] = H
-
-            # inverse free part
-            if 'Y' not in self.global_state:
-                num = H.T
-                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
-                finfo = torch.finfo(H.dtype)
-                Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
-
-            else:
-                Y = self.global_state['Y']
-                I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
-                I -= H @ Y
-                Y = self.global_state['Y'] = Y @ I
-
-
-    def apply(self, var):
-        Y = self.global_state["Y"]
-        params = var.params
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
-
-        # ----------------------------------- solve ---------------------------------- #
-        var.update = vec_to_tensors(Y@g, params)
-
-        return var
-
-    def get_H(self,var):
-        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
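For reference, the InverseFreeNewton class removed above (its replacement presumably lives in torchzero/modules/second_order/ifn.py, entry 106) maintains an approximate inverse Y of the hessian without ever solving a linear system. What the deleted update/apply pair computed, in notation added here:

    Y_0 = \frac{H^{\top}}{\lVert H \rVert_1 \, \lVert H \rVert_\infty}, \qquad
    Y_{k+1} = Y_k \bigl( 2I - H Y_k \bigr), \qquad \text{update} = Y\, g

This is the classical Newton–Schulz (Hotelling–Bodewig) iteration for a matrix inverse; the Y_0 above is a standard initialization for which the iteration converges quadratically to H⁻¹ when H is invertible.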