torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/multipoint.py

@@ -1,19 +1,17 @@
-from collections.abc import Callable
-from contextlib import nullcontext
 from abc import ABC, abstractmethod
+from collections.abc import Callable, Mapping
+from typing import Any
+
 import numpy as np
 import torch

-from ...core import Chainable, Module, apply_transform, Var
-from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    flatten_jacobian,
-    jacobian_wrt,
-)
+from ...core import Chainable, DerivativesMethod, Objective, Transform
+from ...utils import TensorList, vec_to_tensors
+

-class HigherOrderMethodBase(Module, ABC):
-    def __init__(self, defaults: dict | None = None, vectorize: bool = True):
-        self._vectorize = vectorize
+class HigherOrderMethodBase(Transform, ABC):
+    def __init__(self, defaults: dict | None = None, derivatives_method: DerivativesMethod = 'batched_autograd'):
+        self._derivatives_method: DerivativesMethod = derivatives_method
         super().__init__(defaults)

     @abstractmethod
@@ -21,61 +19,27 @@ class HigherOrderMethodBase(Module, ABC):
         self,
         x: torch.Tensor,
         evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
-        var: Var,
+        objective: Objective,
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """"""

     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        x0 = params.clone()
-        closure = var.closure
+    def apply_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+
+        closure = objective.closure
         if closure is None: raise RuntimeError('MultipointNewton requires closure')
-        vectorize = self._vectorize
+        derivatives_method = self._derivatives_method

         def evaluate(x, order) -> tuple[torch.Tensor, ...]:
             """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
-            params.from_vec_(x)
-
-            if order == 0:
-                loss = closure(False)
-                params.copy_(x0)
-                return (loss, )
-
-            if order == 1:
-                with torch.enable_grad():
-                    loss = closure()
-                grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                params.copy_(x0)
-                return loss, torch.cat([g.ravel() for g in grad])
-
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-
-                g_list = torch.autograd.grad(loss, params, create_graph=True)
-                var.grad = list(g_list)
-
-                g = torch.cat([t.ravel() for t in g_list])
-                n = g.numel()
-                ret = [loss, g]
-                T = g # current derivatives tensor
-
-                # get all derivative up to order
-                for o in range(2, order + 1):
-                    is_last = o == order
-                    T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-                    with torch.no_grad() if is_last else nullcontext():
-                        # the shape is (ndim, ) * order
-                        T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                    ret.append(T)
-
-            params.copy_(x0)
-            return tuple(ret)
+            return objective.derivatives_at(x, order, method=derivatives_method)

         x = torch.cat([p.ravel() for p in params])
-        dir = self.one_iteration(x, evaluate, var)
-        var.update = vec_to_tensors(dir, var.params)
-        return var
+        dir = self.one_iteration(x, evaluate, objective, settings[0])
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective

 def _inv(A: torch.Tensor, lstsq:bool) -> torch.Tensor:
     if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable
@@ -106,16 +70,15 @@ class SixthOrder3P(HigherOrderMethodBase):

     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
     """
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_3p(x, f, f_j, lstsq)
+        x_star = sixth_order_3p(x, f, f_j, setting['lstsq'])
         return x - x_star

 # I don't think it works (I tested root finding with this and it goes all over the place)
@@ -173,15 +136,14 @@ def sixth_order_5p(x:torch.Tensor, f_j, lstsq:bool=False):

 class SixthOrder5P(HigherOrderMethodBase):
     """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_5p(x, f_j, lstsq)
+        x_star = sixth_order_5p(x, f_j, setting['lstsq'])
         return x - x_star

 # 2f 1J 2 solves
@@ -196,16 +158,15 @@ class TwoPointNewton(HigherOrderMethodBase):
     """two-point Newton method with frozen derivative with third order convergence.

     Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = two_point_newton(x, f, f_j, lstsq)
+        x_star = two_point_newton(x, f, f_j, setting['lstsq'])
         return x - x_star

 #3f 2J 1inv
@@ -224,15 +185,14 @@ def sixth_order_3pm2(x:torch.Tensor, f, f_j, lstsq:bool=False):

 class SixthOrder3PM2(HigherOrderMethodBase):
     """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
-    def __init__(self, lstsq: bool=False, vectorize: bool = True):
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults, vectorize=vectorize)
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

-    def one_iteration(self, x, evaluate, var):
-        settings = self.defaults
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
         def f(x): return evaluate(x, 1)[1]
-        x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+        x_star = sixth_order_3pm2(x, f, f_j, setting['lstsq'])
         return x - x_star

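For orientation: the hunks above replace the old `step(self, var)` override with `apply_states(self, objective, states, settings)` and route all derivative computation through `objective.derivatives_at`. The following is a minimal sketch of a custom subclass written against the new base class, assuming `HigherOrderMethodBase` is imported from the file diffed above and that no further abstract methods need overriding; the `PlainNewton` name is hypothetical and not part of the package:

import torch
from torchzero.modules.second_order.multipoint import HigherOrderMethodBase

class PlainNewton(HigherOrderMethodBase):
    """Hypothetical subclass: one exact Newton step per iteration."""

    def __init__(self, derivatives_method="batched_autograd"):
        super().__init__(defaults=None, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        # evaluate(x, 2) returns (loss, gradient, hessian), per the docstring in the hunk above
        _, g, H = evaluate(x, 2)
        # the returned vector is the update d; the base class applies x - d
        return torch.linalg.solve(H, g)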
torchzero/modules/second_order/newton.py

@@ -1,21 +1,12 @@
-import warnings
 from collections.abc import Callable
-from functools import partial
 from typing import Literal

 import torch

-from ...core import Chainable, Module, apply_transform, Var
-from ...utils import TensorList, vec_to_tensors
-from ...utils.derivatives import (
-    flatten_jacobian,
-    hessian_mat,
-    hvp,
-    hvp_fd_central,
-    hvp_fd_forward,
-    jacobian_and_hessian_wrt,
-)
-from ...utils.linalg.linear_operator import DenseWithInverse, Dense
+from ...core import Chainable, Transform, Objective, HessianMethod, Module
+from ...utils import vec_to_tensors
+from ...linalg.linear_operator import Dense, DenseWithInverse
+

 def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
@@ -26,10 +17,9 @@ def _lu_solve(H: torch.Tensor, g: torch.Tensor):
         return None

 def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
-    x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
+    L, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
-        g.unsqueeze_(1)
-        return torch.cholesky_solve(g, x)
+        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
     return None

 def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
@@ -49,49 +39,14 @@ def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_n
     except torch.linalg.LinAlgError:
         return None

-
-def _get_loss_grad_and_hessian(var: Var, hessian_method:str, vectorize:bool):
-    """returns (loss, g_list, H). Also sets var.loss and var.grad.
-    If hessian_method isn't 'autograd', loss is not set and returned as None"""
-    closure = var.closure
-    if closure is None:
-        raise RuntimeError("Second order methods requires a closure to be provided to the `step` method.")
-
-    params = var.params
-
-    # ------------------------ calculate grad and hessian ------------------------ #
-    loss = None
-    if hessian_method == 'autograd':
-        with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
-            g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-            g_list = [t[0] for t in g_list] # remove leading dim from loss
-            var.grad = g_list
-            H = flatten_jacobian(H_list)
-
-    elif hessian_method in ('func', 'autograd.functional'):
-        strat = 'forward-mode' if vectorize else 'reverse-mode'
-        with torch.enable_grad():
-            g_list = var.get_grad(retain_graph=True)
-            H = hessian_mat(partial(closure, backward=False), params,
-                            method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-    else:
-        raise ValueError(hessian_method)
-
-    return loss, g_list, H
-
-def _newton_step(var: Var, H: torch.Tensor, damping:float, inner: Module | None, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None) -> torch.Tensor:
-    """returns the update tensor, then do vec_to_tensor(update, params)"""
-    params = var.params
-
-    if damping != 0:
-        H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
-
+def _newton_step(objective: Objective, H: torch.Tensor, damping:float, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None, no_inner: Module | None = None) -> torch.Tensor:
+    """INNER SHOULD BE NONE IN MOST CASES! Because Transform already has inner.
+    Returns the update tensor, then do vec_to_tensor(update, params)"""
     # -------------------------------- inner step -------------------------------- #
-    update = var.get_update()
-    if inner is not None:
-        update = apply_transform(inner, update, params=params, grads=var.grad, loss=var.loss, var=var)
+    if no_inner is not None:
+        objective = no_inner.step(objective)
+
+    update = objective.get_updates()

     g = torch.cat([t.ravel() for t in update])
     if g_proj is not None: g = g_proj(g)
@@ -99,6 +54,9 @@ def _newton_step(var: Var, H: torch.Tensor, damping:float, inner: Module | None,
     # ----------------------------------- solve ---------------------------------- #
     update = None

+    if damping != 0:
+        H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
+
     if H_tfm is not None:
         ret = H_tfm(H, g)

@@ -133,7 +91,7 @@ def _get_H(H: torch.Tensor, eigval_fn):

     return Dense(H)

-class Newton(Module):
+class Newton(Transform):
     """Exact newton's method via autograd.

     Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
@@ -141,7 +99,7 @@ class Newton(Module):
    ``g`` can be output of another module, if it is specifed in ``inner`` argument.

    Note:
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

    Note:
        This module requires the a closure passed to the optimizer step,
@@ -158,10 +116,6 @@ class Newton(Module):
            when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
            If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
            argument will be ignored.
-        hessian_method (str):
-            how to calculate hessian. Defaults to "autograd".
-        vectorize (bool, optional):
-            whether to enable vectorized hessian. Defaults to True.
        H_tfm (Callable | None, optional):
            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.

@@ -174,6 +128,21 @@ class Newton(Module):
        eigval_fn (Callable | None, optional):
            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
            If this is specified, eigendecomposition will be used to invert the hessian.
+        hessian_method (str):
+            Determines how hessian is computed.
+
+            - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
+            - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_fwdrev"`` but uses more memory (``"batched_autograd"`` seems to be faster)
+            - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
+            - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+            - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
+
+            Defaults to ``"batched_autograd"``.
+        h (float, optional):
+            finite difference step size for "fd_forward" and "fd_central".
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.

    # See also
@@ -249,45 +218,43 @@ class Newton(Module):
        damping: float = 0,
        use_lstsq: bool = False,
        update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
-        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['update_freq'], defaults["inner"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
-    def update(self, var):
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
+    def update_states(self, objective, states, settings):
+        fs = settings[0]

-        if step % self.defaults['update_freq'] == 0:
-            loss, g_list, self.global_state['H'] = _get_loss_grad_and_hessian(
-                var, self.defaults['hessian_method'], self.defaults['vectorize']
-            )
+        _, _, self.global_state['H'] = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )

    @torch.no_grad
-    def apply(self, var):
-        params = var.params
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+        fs = settings[0]
+
        update = _newton_step(
-            var=var,
+            objective=objective,
            H = self.global_state["H"],
-            damping=self.defaults["damping"],
-            inner=self.children.get("inner", None),
-            H_tfm=self.defaults["H_tfm"],
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
+            damping = fs["damping"],
+            H_tfm = fs["H_tfm"],
+            eigval_fn = fs["eigval_fn"],
+            use_lstsq = fs["use_lstsq"],
        )

-        var.update = vec_to_tensors(update, params)
-
-        return var
+        objective.updates = vec_to_tensors(update, params)
+        return objective

-    def get_H(self,var=...):
+    def get_H(self,objective=...):
        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])

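Taken together with the constructor hunk above, a 0.4.0-style usage sketch might look like the following. The `Newton` keyword arguments come straight from the diff; `tz.Modular`, `tz.m.Backtracking`, and the `closure(backward=...)` protocol are assumed from the file listing and the docstring references rather than shown in this diff:

import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

opt = tz.Modular(
    model.parameters(),
    # `vectorize` is gone; `hessian_method` and `h` (finite-difference step) replace it
    tz.m.Newton(hessian_method="batched_autograd", damping=1e-4, update_freq=5),
    tz.m.Backtracking(),
)

def closure(backward=True):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)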