torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -8,16 +8,16 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
-    hessian_list_to_mat,
+    flatten_jacobian,
     hessian_mat,
     hvp,
     hvp_fd_central,
     hvp_fd_forward,
     jacobian_and_hessian_wrt,
 )
+from ...utils.linalg.linear_operator import DenseWithInverse, Dense
 
-
-def lu_solve(H: torch.Tensor, g: torch.Tensor):
+def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
         x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
         if info == 0: return x
@@ -25,55 +25,58 @@ def lu_solve(H: torch.Tensor, g: torch.Tensor):
     except RuntimeError:
         return None
 
-def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
+def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
         g.unsqueeze_(1)
         return torch.cholesky_solve(g, x)
     return None
 
-def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
+def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable
 
-def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
+def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
     try:
         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
         if tfm is not None: L = tfm(L)
         if search_negative and L[0] < 0:
-            d = Q[0]
-            # use eigvec or -eigvec depending on if it points in same direction as gradient
-            return g.dot(d).sign() * d
+            neg_mask = L < 0
+            Q_neg = Q[:, neg_mask] * L[neg_mask]
+            return (Q_neg * (g @ Q_neg).sign()).mean(1)
 
         return Q @ ((Q.mH @ g) / L)
 
     except torch.linalg.LinAlgError:
         return None
 
-def tikhonov_(H: torch.Tensor, reg: float):
-    if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
-    return H
+
 
 
 class Newton(Module):
     """Exact newton's method via autograd.
 
-    .. note::
+    Newton's method produces a direction that jumps to the stationary point of the quadratic approximation of the target function.
+    The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian, ``g`` is the gradient and ``y`` is the ``damping`` parameter.
+    ``g`` can be the output of another module if that module is specified via the ``inner`` argument.
+
+    Note:
         In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
 
-    .. note::
+    Note:
         This module requires a closure to be passed to the optimizer step,
         as it needs to re-evaluate the loss and gradients for calculating the hessian.
         The closure must accept a ``backward`` argument (refer to documentation).
 
-    .. warning::
-        this uses roughly O(N^2) memory.
-
-
     Args:
-        reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
+        damping (float, optional): tikhonov regularizer value. Set this to 0 when using trust region. Defaults to 0.
        search_negative (bool, Optional):
            if True, whenever a negative eigenvalue is detected,
-            search direction is proposed along an eigenvector corresponding to a negative eigenvalue.
+            the search direction is proposed along a weighted sum of the eigenvectors corresponding to negative eigenvalues.
+        use_lstsq (bool, Optional):
+            if True, least squares is used to solve the linear system, which may generate reasonable directions
+            when the hessian is not invertible. If False, cholesky is tried first, then LU, then least squares.
+            If ``eigval_fn`` is specified, eigendecomposition is always used to solve the linear system and this
+            argument is ignored.
        hessian_method (str):
            how to calculate the hessian. Defaults to "autograd".
        vectorize (bool, optional):
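As a quick, self-contained illustration of the rewritten negative-curvature branch of ``_eigh_solve`` in the hunk above: the sketch below mirrors the new logic with plain PyTorch on a toy indefinite matrix. The tensor names follow the diff, but nothing here is torchzero API.

```py
import torch

# Small indefinite hessian: one negative eigenvalue, which plain Newton handles poorly.
H = torch.tensor([[2.0, 0.0],
                  [0.0, -1.0]])
g = torch.tensor([0.3, 0.7])

L, Q = torch.linalg.eigh(H)                    # eigenvalues in ascending order
if L[0] < 0:
    neg_mask = L < 0
    Q_neg = Q[:, neg_mask] * L[neg_mask]       # negative-curvature eigenvectors, weighted by their eigenvalues
    d = (Q_neg * (g @ Q_neg).sign()).mean(1)   # flip each column to correlate with g, then average
else:
    d = Q @ ((Q.mH @ g) / L)                   # ordinary eigendecomposition solve

print(d)  # direction restricted to the negative-curvature subspace
```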
@@ -88,92 +91,107 @@ class Newton(Module):
             Or it returns a single tensor which is used as the update.
 
             Defaults to None.
-        eigval_tfm (Callable | None, optional):
-            optional eigenvalues transform, for example :code:`torch.abs` or :code:`lambda L: torch.clip(L, min=1e-8)`.
+        eigval_fn (Callable | None, optional):
+            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
             If this is specified, eigendecomposition will be used to invert the hessian.
 
-    Examples:
-        Newton's method with backtracking line search
+    # See also
+
+    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products,
+      useful for large scale problems as it doesn't form the full hessian.
+    * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
+    * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
+    * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.
 
-        .. code-block:: python
+    # Notes
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Newton(),
-                tz.m.Backtracking()
-            )
+    ## Implementation details
 
-        Newton's method modified for non-convex functions by taking matrix absolute value of the hessian
+    ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
+    The linear system is solved via cholesky decomposition; if that fails, LU decomposition; and if that fails, least squares.
+    Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when the linear system is overdetermined.
 
-        .. code-block:: python
+    Additionally, if ``eigval_fn`` is specified or ``search_negative`` is ``True``,
+    an eigendecomposition of the hessian is computed, ``eigval_fn`` is applied to the eigenvalues,
+    and ``(H + yI)⁻¹`` is computed from the eigenvectors and the transformed eigenvalues.
+    This is generally more computationally expensive.
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Newton(eigval_tfm=lambda x: torch.abs(x).clip(min=0.1)),
-                tz.m.Backtracking()
-            )
+    ## Handling non-convexity
 
-        Newton's method modified for non-convex functions by searching along negative curvature directions
+    Standard Newton's method does not handle non-convexity well without some modifications.
+    This is because it jumps to the stationary point, which may be a maximum of the quadratic approximation.
 
-        .. code-block:: python
+    The first modification to handle non-convexity is to force the eigenvalues to be positive,
+    for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Newton(search_negative=True),
-                tz.m.Backtracking()
-            )
+    The second modification is ``search_negative=True``, which searches along a negative curvature direction if one is detected.
+    This also requires an eigendecomposition.
 
-        Newton preconditioning applied to momentum
+    The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
+    but that may be significantly less efficient.
 
-        .. code-block:: python
+    # Examples:
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Newton(inner=tz.m.EMA(0.9)),
-                tz.m.LR(0.1)
-            )
+    Newton's method with backtracking line search
 
-        Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient, but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(),
+        tz.m.Backtracking()
+    )
+    ```
 
-        .. code-block:: python
+    Newton preconditioning applied to momentum
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
-                tz.m.Backtracking()
-            )
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(inner=tz.m.EMA(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient,
+    but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
+        tz.m.Backtracking()
+    )
+    ```
 
     """
     def __init__(
         self,
-        reg: float = 1e-6,
+        damping: float = 0,
         search_negative: bool = False,
+        use_lstsq: bool = False,
         update_freq: int = 1,
         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
         vectorize: bool = True,
         inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-        eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative, update_freq=update_freq)
+        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, search_negative=search_negative, update_freq=update_freq)
         super().__init__(defaults)
 
         if inner is not None:
             self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self, var):
+    def update(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
-        reg = settings['reg']
-        search_negative = settings['search_negative']
+        damping = settings['damping']
         hessian_method = settings['hessian_method']
         vectorize = settings['vectorize']
-        H_tfm = settings['H_tfm']
-        eigval_tfm = settings['eigval_tfm']
         update_freq = settings['update_freq']
 
         step = self.global_state.get('step', 0)
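The solve order described in the docstring above (cholesky, then LU, then least squares, with the ``damping`` term added as ``yI``) can be sketched in a few lines of plain PyTorch. The helper name ``damped_newton_direction`` below is hypothetical, introduced only for this sketch; it is not part of torchzero.

```py
import torch

def damped_newton_direction(H, g, damping=0.0, use_lstsq=False):
    """Solve (H + damping*I) x = g, falling back to progressively more robust solvers."""
    A = H + damping * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)
    if not use_lstsq:
        C, info = torch.linalg.cholesky_ex(A)            # fastest, needs positive definiteness
        if info == 0:
            return torch.cholesky_solve(g.unsqueeze(1), C).squeeze(1)
        x, info = torch.linalg.solve_ex(A, g)            # LU, needs invertibility
        if info == 0:
            return x
    return torch.linalg.lstsq(A, g.unsqueeze(1)).solution.squeeze(1)  # always returns something

H = torch.tensor([[4.0, 1.0],
                  [1.0, 3.0]])
g = torch.tensor([1.0, 2.0])
print(damped_newton_direction(H, g, damping=1e-6))
```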
@@ -189,7 +207,7 @@ class Newton(Module):
             g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
             g_list = [t[0] for t in g_list] # remove leading dim from loss
             var.grad = g_list
-            H = hessian_list_to_mat(H_list)
+            H = flatten_jacobian(H_list)
 
         elif hessian_method in ('func', 'autograd.functional'):
             strat = 'forward-mode' if vectorize else 'reverse-mode'
@@ -201,23 +219,27 @@ class Newton(Module):
         else:
             raise ValueError(hessian_method)
 
-        H = tikhonov_(H, reg)
-        if update_freq != 1:
-            self.global_state['H'] = H
+        if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
+        self.global_state['H'] = H
 
-        if H is None:
-            H = self.global_state["H"]
+    @torch.no_grad
+    def apply(self, var):
+        H = self.global_state["H"]
 
-        # var.storage['hessian'] = H
+        params = var.params
+        settings = self.settings[params[0]]
+        search_negative = settings['search_negative']
+        H_tfm = settings['H_tfm']
+        eigval_fn = settings['eigval_fn']
+        use_lstsq = settings['use_lstsq']
 
         # -------------------------------- inner step -------------------------------- #
         update = var.get_update()
         if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
 
         g = torch.cat([t.ravel() for t in update])
 
-
         # ----------------------------------- solve ---------------------------------- #
         update = None
         if H_tfm is not None:
@@ -230,17 +252,35 @@ class Newton(Module):
             H, is_inv = ret
             if is_inv: update = H @ g
 
-        if search_negative or (eigval_tfm is not None):
-            update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)
+        if search_negative or (eigval_fn is not None):
+            update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)
 
-        if update is None: update = cholesky_solve(H, g)
-        if update is None: update = lu_solve(H, g)
-        if update is None: update = least_squares_solve(H, g)
+        if update is None and use_lstsq: update = _least_squares_solve(H, g)
+        if update is None: update = _cholesky_solve(H, g)
+        if update is None: update = _lu_solve(H, g)
+        if update is None: update = _least_squares_solve(H, g)
 
         var.update = vec_to_tensors(update, params)
 
         return var
 
+    def get_H(self, var):
+        H = self.global_state["H"]
+        settings = self.defaults
+        if settings['eigval_fn'] is not None:
+            try:
+                L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+                L = settings['eigval_fn'](L)
+                H = Q @ L.diag_embed() @ Q.mH
+                H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
+                return DenseWithInverse(H, H_inv)
+
+            except torch.linalg.LinAlgError:
+                pass
+
+        return Dense(H)
+
+
 
 class InverseFreeNewton(Module):
     """Inverse-free newton's method
@@ -272,7 +312,7 @@ class InverseFreeNewton(Module):
         self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self, var):
+    def update(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
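For reference, the ``get_H`` hook added to ``Newton`` in the hunks above rebuilds both the eigenvalue-modified hessian and its inverse from a single eigendecomposition. The following standalone check illustrates that identity in plain PyTorch, leaving the ``DenseWithInverse`` wrapper aside:

```py
import torch

eigval_fn = lambda L: L.abs().clip(min=1e-4)    # example transform from the docstring

H = torch.tensor([[2.0, 0.5],
                  [0.5, -1.0]])                 # indefinite on purpose
L, Q = torch.linalg.eigh(H)
L = eigval_fn(L)                                # now strictly positive

H_mod = Q @ L.diag_embed() @ Q.mH               # modified (positive definite) hessian
H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH  # its exact inverse, no extra solve needed

print(torch.allclose(H_mod @ H_inv, torch.eye(2), atol=1e-5))  # True
```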
@@ -295,7 +335,7 @@ class InverseFreeNewton(Module):
             g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
             g_list = [t[0] for t in g_list] # remove leading dim from loss
             var.grad = g_list
-            H = hessian_list_to_mat(H_list)
+            H = flatten_jacobian(H_list)
 
         elif hessian_method in ('func', 'autograd.functional'):
             strat = 'forward-mode' if vectorize else 'reverse-mode'
@@ -307,12 +347,14 @@ class InverseFreeNewton(Module):
         else:
             raise ValueError(hessian_method)
 
+        self.global_state["H"] = H
+
         # inverse free part
         if 'Y' not in self.global_state:
             num = H.T
             denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
-            eps = torch.finfo(H.dtype).eps
-            Y = self.global_state['Y'] = num.div_(denom.clip(min=eps, max=1/eps))
+            finfo = torch.finfo(H.dtype)
+            Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
 
         else:
             Y = self.global_state['Y']
@@ -320,19 +362,22 @@ class InverseFreeNewton(Module):
             I -= H @ Y
             Y = self.global_state['Y'] = Y @ I
 
-        if Y is None:
-            Y = self.global_state["Y"]
 
+    def apply(self, var):
+        Y = self.global_state["Y"]
+        params = var.params
 
         # -------------------------------- inner step -------------------------------- #
         update = var.get_update()
         if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
 
         g = torch.cat([t.ravel() for t in update])
 
-
         # ----------------------------------- solve ---------------------------------- #
         var.update = vec_to_tensors(Y@g, params)
 
         return var
+
+    def get_H(self, var):
+        return DenseWithInverse(A=self.global_state["H"], A_inv=self.global_state["Y"])
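
The ``InverseFreeNewton`` update shown above (``I -= H @ Y`` followed by ``Y = Y @ I``) is the Newton–Schulz iteration ``Y ← Y(2I − HY)``, assuming ``I`` is rebuilt as ``2 * identity`` on each step (that initialization sits outside these hunks). A minimal standalone sketch under that assumption, using the same scaled-transpose initialization as the diff:

```py
import torch

torch.manual_seed(0)
A = torch.randn(5, 5)
H = A @ A.T + torch.eye(5)                       # symmetric positive definite test matrix

# Scaled-transpose initialization, as in the hunk: Y0 = H^T / (||H||_1 * ||H||_inf)
Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf')))

for _ in range(30):
    I2 = 2 * torch.eye(5)                        # assumed meaning of "I" in the hunk
    I2 -= H @ Y
    Y = Y @ I2                                   # Newton–Schulz step: Y <- Y (2I - H Y)

print(torch.allclose(Y @ H, torch.eye(5), atol=1e-4))  # Y has converged to H^-1
```

Because ``Y`` tracks an approximate inverse directly, the new ``get_H`` can hand back ``DenseWithInverse`` without ever calling a matrix inverse.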