torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
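
Most of the entries above are a reorganization rather than new functionality: the `optimizers` subpackage is split into `adaptive` and other subpackages (`conjugate_gradient`, `misc`, `trust_region`, ...), several experimental modules are deleted, and the detailed hunks that follow appear to come from `torchzero/modules/second_order/newton.py` (listed above with +292 -68). As a hedged sketch of what the file moves mean for imports, assuming code imported from the old subpackage paths directly and that the `Adam` class name is unchanged (code that goes through the `tz.m` namespace should be unaffected by the move):

```py
# Hypothetical compatibility import for the optimizers -> adaptive move shown in the file list.
# The `Adam` class name is an assumption for illustration; adjust to whatever your code imports.
try:
    from torchzero.modules.adaptive.adam import Adam     # 0.3.13 layout
except ImportError:
    from torchzero.modules.optimizers.adam import Adam   # 0.3.10 layout
```
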
@@ -8,16 +8,16 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
-    hessian_list_to_mat,
+    flatten_jacobian,
     hessian_mat,
     hvp,
     hvp_fd_central,
     hvp_fd_forward,
     jacobian_and_hessian_wrt,
 )
+from ...utils.linalg.linear_operator import DenseWithInverse, Dense
 
-
-def lu_solve(H: torch.Tensor, g: torch.Tensor):
+def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
         x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
         if info == 0: return x
@@ -25,135 +25,359 @@ def lu_solve(H: torch.Tensor, g: torch.Tensor):
     except RuntimeError:
         return None
 
-def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
+def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
         g.unsqueeze_(1)
         return torch.cholesky_solve(g, x)
     return None
 
-def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
+def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable
 
-def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
+def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
     try:
         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
         if tfm is not None: L = tfm(L)
         if search_negative and L[0] < 0:
-            d = Q[0]
-            # use eigvec or -eigvec depending on if it points in same direction as gradient
-            return g.dot(d).sign() * d
+            neg_mask = L < 0
+            Q_neg = Q[:, neg_mask] * L[neg_mask]
+            return (Q_neg * (g @ Q_neg).sign()).mean(1)
+
+        return Q @ ((Q.mH @ g) / L)
 
-        L.reciprocal_()
-        return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
     except torch.linalg.LinAlgError:
         return None
 
-def tikhonov_(H: torch.Tensor, reg: float):
-    if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
-    return H
 
-def eig_tikhonov_(H: torch.Tensor, reg: float):
-    v = torch.linalg.eigvalsh(H).min().clamp_(max=0).neg_() + reg # pylint:disable=not-callable
-    return tikhonov_(H, v)
 
 
 class Newton(Module):
-    """Exact newton via autograd.
+    """Exact newton's method via autograd.
+
+    Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
+    The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.
+    ``g`` can be output of another module, if it is specifed in ``inner`` argument.
+
+    Note:
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    Note:
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
-        reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-        eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+        damping (float, optional): tikhonov regularizer value. Set this to 0 when using trust region. Defaults to 0.
         search_negative (bool, Optional):
-            if True, whenever a negative eigenvalue is detected, the direction is taken along an eigenvector corresponding to a negative eigenvalue.
+            if True, whenever a negative eigenvalue is detected,
+            search direction is proposed along weighted sum of eigenvectors corresponding to negative eigenvalues.
+        use_lstsq (bool, Optional):
+            if True, least squares will be used to solve the linear system, this may generate reasonable directions
+            when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
+            If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
+            argument will be ignored.
         hessian_method (str):
             how to calculate hessian. Defaults to "autograd".
         vectorize (bool, optional):
            whether to enable vectorized hessian. Defaults to True.
-        inner (Chainable | None, optional): inner modules. Defaults to None.
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         H_tfm (Callable | None, optional):
            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
 
-            must return a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
-            which must be True if transform inverted the hessian and False otherwise. Defaults to None.
-        eigval_tfm (Callable | None, optional):
-            optional eigenvalues transform, for example :code:`torch.abs` or :code:`lambda L: torch.clip(L, min=1e-8)`.
-            If this is specified, eigendecomposition will be used to solve Hx = g.
+            must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+            which must be True if transform inverted the hessian and False otherwise.
+
+            Or it returns a single tensor which is used as the update.
+
+            Defaults to None.
+        eigval_fn (Callable | None, optional):
+            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
+            If this is specified, eigendecomposition will be used to invert the hessian.
+
+    # See also
+
+    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products,
+      useful for large scale problems as it doesn't form the full hessian.
+    * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
+    * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
+    * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.
+
+    # Notes
+
+    ## Implementation details
+
+    ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
+    The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
+    Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
+
+    Additionally, if ``eigval_fn`` is specified or ``search_negative`` is ``True``,
+    eigendecomposition of the hessian is computed, ``eigval_fn`` is applied to the eigenvalues,
+    and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues.
+    This is more generally more computationally expensive.
+
+    ## Handling non-convexity
+
+    Standard Newton's method does not handle non-convexity well without some modifications.
+    This is because it jumps to the stationary point, which may be the maxima of the quadratic approximation.
+
+    The first modification to handle non-convexity is to modify the eignevalues to be positive,
+    for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.
+
+    Second modification is ``search_negative=True``, which will search along a negative curvature direction if one is detected.
+    This also requires an eigendecomposition.
+
+    The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
+    but that may be significantly less efficient.
+
+    # Examples:
+
+    Newton's method with backtracking line search
+
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    Newton preconditioning applied to momentum
+
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(inner=tz.m.EMA(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient,
+    but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
+        tz.m.Backtracking()
+    )
+    ```
 
     """
     def __init__(
         self,
-        reg: float = 1e-6,
-        eig_reg: bool = False,
+        damping: float = 0,
         search_negative: bool = False,
+        use_lstsq: bool = False,
+        update_freq: int = 1,
         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
         vectorize: bool = True,
         inner: Chainable | None = None,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
-        eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(reg=reg, eig_reg=eig_reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative)
+        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, search_negative=search_negative, update_freq=update_freq)
         super().__init__(defaults)
 
         if inner is not None:
             self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self, var):
+    def update(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
-        reg = settings['reg']
-        eig_reg = settings['eig_reg']
-        search_negative = settings['search_negative']
+        damping = settings['damping']
         hessian_method = settings['hessian_method']
         vectorize = settings['vectorize']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        g_list = var.grad
+        H = None
+        if step % update_freq == 0:
+            # ------------------------ calculate grad and hessian ------------------------ #
+            if hessian_method == 'autograd':
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                    g_list = [t[0] for t in g_list] # remove leading dim from loss
+                    var.grad = g_list
+                    H = flatten_jacobian(H_list)
+
+            elif hessian_method in ('func', 'autograd.functional'):
+                strat = 'forward-mode' if vectorize else 'reverse-mode'
+                with torch.enable_grad():
+                    g_list = var.get_grad(retain_graph=True)
+                    H = hessian_mat(partial(closure, backward=False), params,
+                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+            else:
+                raise ValueError(hessian_method)
+
+            if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
+            self.global_state['H'] = H
+
+    @torch.no_grad
+    def apply(self, var):
+        H = self.global_state["H"]
+
+        params = var.params
+        settings = self.settings[params[0]]
+        search_negative = settings['search_negative']
         H_tfm = settings['H_tfm']
-        eigval_tfm = settings['eigval_tfm']
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        if hessian_method == 'autograd':
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                g_list = [t[0] for t in g_list] # remove leading dim from loss
-                var.grad = g_list
-                H = hessian_list_to_mat(H_list)
-
-        elif hessian_method in ('func', 'autograd.functional'):
-            strat = 'forward-mode' if vectorize else 'reverse-mode'
-            with torch.enable_grad():
-                g_list = var.get_grad(retain_graph=True)
-                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-        else:
-            raise ValueError(hessian_method)
+        eigval_fn = settings['eigval_fn']
+        use_lstsq = settings['use_lstsq']
 
         # -------------------------------- inner step -------------------------------- #
         update = var.get_update()
         if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=list(g_list), var=var)
-        g = torch.cat([t.ravel() for t in update])
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
 
-        # ------------------------------- regulazition ------------------------------- #
-        if eig_reg: H = eig_tikhonov_(H, reg)
-        else: H = tikhonov_(H, reg)
+        g = torch.cat([t.ravel() for t in update])
 
         # ----------------------------------- solve ---------------------------------- #
         update = None
         if H_tfm is not None:
-            H, is_inv = H_tfm(H, g)
-            if is_inv: update = H @ g
+            ret = H_tfm(H, g)
+
+            if isinstance(ret, torch.Tensor):
+                update = ret
+
+            else: # returns (H, is_inv)
+                H, is_inv = ret
+                if is_inv: update = H @ g
 
-        if search_negative or (eigval_tfm is not None):
-            update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)
+        if search_negative or (eigval_fn is not None):
+            update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)
 
-        if update is None: update = cholesky_solve(H, g)
-        if update is None: update = lu_solve(H, g)
-        if update is None: update = least_squares_solve(H, g)
+        if update is None and use_lstsq: update = _least_squares_solve(H, g)
+        if update is None: update = _cholesky_solve(H, g)
+        if update is None: update = _lu_solve(H, g)
+        if update is None: update = _least_squares_solve(H, g)
 
         var.update = vec_to_tensors(update, params)
+
+        return var
+
+    def get_H(self,var):
+        H = self.global_state["H"]
+        settings = self.defaults
+        if settings['eigval_fn'] is not None:
+            try:
+                L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+                L = settings['eigval_fn'](L)
+                H = Q @ L.diag_embed() @ Q.mH
+                H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
+                return DenseWithInverse(H, H_inv)
+
+            except torch.linalg.LinAlgError:
+                pass
+
+        return Dense(H)
+
+
+class InverseFreeNewton(Module):
+    """Inverse-free newton's method
+
+    .. note::
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference
+        Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
+    """
+    def __init__(
+        self,
+        update_freq: int = 1,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        hessian_method = settings['hessian_method']
+        vectorize = settings['vectorize']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        g_list = var.grad
+        Y = None
+        if step % update_freq == 0:
+            # ------------------------ calculate grad and hessian ------------------------ #
+            if hessian_method == 'autograd':
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                    g_list = [t[0] for t in g_list] # remove leading dim from loss
+                    var.grad = g_list
+                    H = flatten_jacobian(H_list)
+
+            elif hessian_method in ('func', 'autograd.functional'):
+                strat = 'forward-mode' if vectorize else 'reverse-mode'
+                with torch.enable_grad():
+                    g_list = var.get_grad(retain_graph=True)
+                    H = hessian_mat(partial(closure, backward=False), params,
+                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+            else:
+                raise ValueError(hessian_method)
+
+            self.global_state["H"] = H
+
+            # inverse free part
+            if 'Y' not in self.global_state:
+                num = H.T
+                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+                finfo = torch.finfo(H.dtype)
+                Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
+
+            else:
+                Y = self.global_state['Y']
+                I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+                I -= H @ Y
+                Y = self.global_state['Y'] = Y @ I
+
+
+    def apply(self, var):
+        Y = self.global_state["Y"]
+        params = var.params
+
+        # -------------------------------- inner step -------------------------------- #
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
+
+        g = torch.cat([t.ravel() for t in update])
+
+        # ----------------------------------- solve ---------------------------------- #
+        var.update = vec_to_tensors(Y@g, params)
+
         return var
+
+    def get_H(self,var):
+        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
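
To make the renamed arguments above concrete, here is a minimal usage sketch of the 0.3.13 `Newton` module, mirroring the docstring examples in the diff. The toy model, data, and training loop are illustrative rather than taken from the package, and the closure follows the `backward`-flag protocol the docstring refers to:

```py
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)                  # illustrative toy model
X, y = torch.randn(64, 4), torch.randn(64, 1)  # illustrative data

# The 0.3.10 arguments `reg`, `eig_reg` and `eigval_tfm` are gone; 0.3.13 uses
# `damping` and `eigval_fn` (plus the new `use_lstsq` and `update_freq`).
opt = tz.Modular(
    model.parameters(),
    tz.m.Newton(damping=0.0, eigval_fn=lambda L: L.abs().clip(min=1e-8)),
    tz.m.Backtracking(),
)

def closure(backward=True):
    # closure accepts a `backward` flag so the module can re-evaluate the loss
    # without building gradients when computing the hessian
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
```
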