torchzero 0.3.14__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. tests/test_opts.py +4 -3
  2. torchzero/core/__init__.py +4 -1
  3. torchzero/core/chain.py +50 -0
  4. torchzero/core/functional.py +37 -0
  5. torchzero/core/modular.py +237 -0
  6. torchzero/core/module.py +8 -599
  7. torchzero/core/reformulation.py +3 -1
  8. torchzero/core/transform.py +7 -5
  9. torchzero/core/var.py +376 -0
  10. torchzero/modules/__init__.py +0 -1
  11. torchzero/modules/adaptive/adahessian.py +2 -2
  12. torchzero/modules/adaptive/esgd.py +2 -2
  13. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  14. torchzero/modules/adaptive/sophia_h.py +2 -2
  15. torchzero/modules/experimental/__init__.py +1 -0
  16. torchzero/modules/experimental/newtonnewton.py +5 -5
  17. torchzero/modules/experimental/spsa1.py +2 -2
  18. torchzero/modules/functional.py +7 -0
  19. torchzero/modules/line_search/__init__.py +1 -1
  20. torchzero/modules/line_search/_polyinterp.py +3 -1
  21. torchzero/modules/line_search/adaptive.py +3 -3
  22. torchzero/modules/line_search/backtracking.py +1 -1
  23. torchzero/modules/line_search/interpolation.py +160 -0
  24. torchzero/modules/line_search/line_search.py +11 -20
  25. torchzero/modules/line_search/strong_wolfe.py +3 -3
  26. torchzero/modules/misc/misc.py +2 -2
  27. torchzero/modules/misc/multistep.py +13 -13
  28. torchzero/modules/quasi_newton/__init__.py +2 -0
  29. torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  30. torchzero/modules/quasi_newton/sg2.py +292 -0
  31. torchzero/modules/second_order/__init__.py +6 -3
  32. torchzero/modules/second_order/ifn.py +89 -0
  33. torchzero/modules/second_order/inm.py +105 -0
  34. torchzero/modules/second_order/newton.py +103 -193
  35. torchzero/modules/second_order/nystrom.py +1 -1
  36. torchzero/modules/second_order/rsn.py +227 -0
  37. torchzero/modules/wrappers/optim_wrapper.py +49 -42
  38. torchzero/utils/derivatives.py +19 -19
  39. torchzero/utils/linalg/linear_operator.py +50 -2
  40. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
  41. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/RECORD +44 -36
  42. torchzero/modules/higher_order/__init__.py +0 -1
  43. /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
  44. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
  45. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/ifn.py
@@ -0,0 +1,89 @@
+ import warnings
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, apply_transform, Var
+ from ...utils import TensorList, vec_to_tensors
+ from ...utils.linalg.linear_operator import DenseWithInverse, Dense
+ from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+
+
+ class InverseFreeNewton(Module):
+     """Inverse-free Newton's method.
+
+     .. note::
+         In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+     .. note::
+         This module requires a closure to be passed to the optimizer step,
+         as it needs to re-evaluate the loss and gradients for calculating the hessian.
+         The closure must accept a ``backward`` argument (refer to documentation).
+
+     .. warning::
+         This uses roughly O(N^2) memory.
+
+     Reference:
+         [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
+     """
+     def __init__(
+         self,
+         update_freq: int = 1,
+         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+         vectorize: bool = True,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def update(self, var):
+         update_freq = self.defaults['update_freq']
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         if step % update_freq == 0:
+             loss, g_list, H = _get_loss_grad_and_hessian(
+                 var, self.defaults['hessian_method'], self.defaults['vectorize']
+             )
+             self.global_state["H"] = H
+
+             # inverse free part
+             if 'Y' not in self.global_state:
+                 num = H.T
+                 denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+
+                 finfo = torch.finfo(H.dtype)
+                 self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
+
+             else:
+                 Y = self.global_state['Y']
+                 I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+                 I2 -= H @ Y
+                 self.global_state['Y'] = Y @ I2
+
+
+     def apply(self, var):
+         Y = self.global_state["Y"]
+         params = var.params
+
+         # -------------------------------- inner step -------------------------------- #
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
+
+         g = torch.cat([t.ravel() for t in update])
+
+         # ----------------------------------- solve ---------------------------------- #
+         var.update = vec_to_tensors(Y@g, params)
+
+         return var
+
+     def get_H(self,var):
+         return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
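The ``Y`` recursion above is the Newton–Schulz iteration for approximating ``H⁻¹``: ``Y₀ = Hᵀ / (‖H‖₁‖H‖∞)``, then ``Y_{k+1} = Y_k (2I − H Y_k)``. Below is a minimal standalone sketch of that iteration in plain PyTorch; the function name and the toy loop are illustrative and not part of the torchzero API.

```python
import torch

def newton_schulz_direction(H: torch.Tensor, g: torch.Tensor, Y: torch.Tensor | None):
    """One refinement of the inverse-free recursion used by InverseFreeNewton above.

    Y approximates H^{-1}; pass None on the first call. Returns (direction, new Y).
    Illustrative sketch, not the torchzero module API."""
    if Y is None:
        # Y0 = H^T / (||H||_1 * ||H||_inf): a standard initialization that makes
        # the iteration converge for nonsingular H
        Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float("inf")))
    else:
        # Y_{k+1} = Y_k (2I - H Y_k): the error I - H Y is squared at every step
        I2 = torch.eye(H.size(0), dtype=H.dtype, device=H.device).mul_(2)
        Y = Y @ (I2 - H @ Y)
    return Y @ g, Y

# toy check on a small SPD matrix: the direction approaches the exact Newton step
H = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
g = torch.tensor([1.0, 2.0])
Y = None
for _ in range(10):
    d, Y = newton_schulz_direction(H, g, Y)
print(torch.allclose(d, torch.linalg.solve(H, g), atol=1e-5))  # True
```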
torchzero/modules/second_order/inm.py
@@ -0,0 +1,105 @@
+ from collections.abc import Callable
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module
+ from ...utils import TensorList, vec_to_tensors
+ from ..functional import safe_clip
+ from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+
+ @torch.no_grad
+ def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
+
+     yy = safe_clip(y.dot(y))
+     ss = safe_clip(s.dot(s))
+
+     term1 = y.dot(y - J@s) / yy
+     FbT = f.outer(s).mul_(term1 / ss)
+
+     P = FbT.add_(J)
+     return P
+
+ def _eigval_fn(J: torch.Tensor, fn) -> torch.Tensor:
+     if fn is None: return J
+     L, Q = torch.linalg.eigh(J) # pylint:disable=not-callable
+     return (Q * L.unsqueeze(-2)) @ Q.mH
+
+ class INM(Module):
+     """Improved Newton's Method (INM).
+
+     Reference:
+         [Saheya, B., et al. "A new Newton-like method for solving nonlinear equations." SpringerPlus 5.1 (2016): 1269.](https://d-nb.info/1112813721/34)
+     """
+
+     def __init__(
+         self,
+         damping: float = 0,
+         use_lstsq: bool = False,
+         update_freq: int = 1,
+         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+         vectorize: bool = True,
+         inner: Chainable | None = None,
+         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+     ):
+         defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child("inner", inner)
+
+     @torch.no_grad
+     def update(self, var):
+         update_freq = self.defaults['update_freq']
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         if step % update_freq == 0:
+             _, f_list, J = _get_loss_grad_and_hessian(
+                 var, self.defaults['hessian_method'], self.defaults['vectorize']
+             )
+
+             f = torch.cat([t.ravel() for t in f_list])
+             J = _eigval_fn(J, self.defaults["eigval_fn"])
+
+             x_list = TensorList(var.params)
+             f_list = TensorList(var.get_grad())
+             x_prev, f_prev = self.get_state(var.params, "x_prev", "f_prev", cls=TensorList)
+
+             # initialize on 1st step, do Newton step
+             if step == 0:
+                 x_prev.copy_(x_list)
+                 f_prev.copy_(f_list)
+                 self.global_state["P"] = J
+                 return
+
+             # INM update
+             s_list = x_list - x_prev
+             y_list = f_list - f_prev
+             x_prev.copy_(x_list)
+             f_prev.copy_(f_list)
+
+             self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
+
+
+     @torch.no_grad
+     def apply(self, var):
+         params = var.params
+         update = _newton_step(
+             var=var,
+             H = self.global_state["P"],
+             damping=self.defaults["damping"],
+             inner=self.children.get("inner", None),
+             H_tfm=self.defaults["H_tfm"],
+             eigval_fn=None, # it is applied in `update`
+             use_lstsq=self.defaults["use_lstsq"],
+         )
+
+         var.update = vec_to_tensors(update, params)
+
+         return var
+
+     def get_H(self,var=...):
+         return _get_H(self.global_state["P"], eigval_fn=None)
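The ``inm`` helper above implements the rank-one Jacobian correction from Saheya et al.: ``P = J + c · f sᵀ`` with ``c = yᵀ(y − J s) / ((yᵀy)(sᵀs))``, where ``s`` and ``y`` are the parameter and gradient differences. A self-contained sketch of just that formula in plain torch; ``inm_matrix`` and the ``eps`` clipping are illustrative stand-ins for the module's ``safe_clip``.

```python
import torch

def inm_matrix(f: torch.Tensor, J: torch.Tensor, s: torch.Tensor, y: torch.Tensor,
               eps: float = 1e-12) -> torch.Tensor:
    """Rank-one corrected Jacobian P = J + c * f s^T, c = y.(y - J s) / (y.y * s.s).

    Mirrors the `inm` helper in the hunk above; `eps` stands in for `safe_clip`."""
    c = y.dot(y - J @ s) / (y.dot(y).clamp(min=eps) * s.dot(s).clamp(min=eps))
    return J + c * torch.outer(f, s)

# toy usage: correct a Jacobian estimate with the latest (s, y) pair
torch.manual_seed(0)
J = torch.randn(3, 3)
f, s, y = torch.randn(3), torch.randn(3), torch.randn(3)
P = inm_matrix(f, J, s, y)
print(P.shape)  # torch.Size([3, 3])
```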
torchzero/modules/second_order/newton.py
@@ -5,7 +5,7 @@ from typing import Literal
  
  import torch
  
- from ...core import Chainable, Module, apply_transform
+ from ...core import Chainable, Module, apply_transform, Var
  from ...utils import TensorList, vec_to_tensors
  from ...utils.derivatives import (
      flatten_jacobian,
@@ -50,7 +50,88 @@ def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_n
      return None
  
  
+ def _get_loss_grad_and_hessian(var: Var, hessian_method:str, vectorize:bool):
+     """returns (loss, g_list, H). Also sets var.loss and var.grad.
+     If hessian_method isn't 'autograd', loss is not set and returned as None"""
+     closure = var.closure
+     if closure is None:
+         raise RuntimeError("Second order methods require a closure to be provided to the `step` method.")
  
+     params = var.params
+
+     # ------------------------ calculate grad and hessian ------------------------ #
+     loss = None
+     if hessian_method == 'autograd':
+         with torch.enable_grad():
+             loss = var.loss = var.loss_approx = closure(False)
+             g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+             g_list = [t[0] for t in g_list] # remove leading dim from loss
+             var.grad = g_list
+             H = flatten_jacobian(H_list)
+
+     elif hessian_method in ('func', 'autograd.functional'):
+         strat = 'forward-mode' if vectorize else 'reverse-mode'
+         with torch.enable_grad():
+             g_list = var.get_grad(retain_graph=True)
+             H = hessian_mat(partial(closure, backward=False), params,
+                             method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+     else:
+         raise ValueError(hessian_method)
+
+     return loss, g_list, H
+
+ def _newton_step(var: Var, H: torch.Tensor, damping:float, inner: Module | None, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None) -> torch.Tensor:
+     """returns the update tensor, then do vec_to_tensors(update, params)"""
+     params = var.params
+
+     if damping != 0:
+         H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
+
+     # -------------------------------- inner step -------------------------------- #
+     update = var.get_update()
+     if inner is not None:
+         update = apply_transform(inner, update, params=params, grads=var.grad, loss=var.loss, var=var)
+
+     g = torch.cat([t.ravel() for t in update])
+     if g_proj is not None: g = g_proj(g)
+
+     # ----------------------------------- solve ---------------------------------- #
+     update = None
+
+     if H_tfm is not None:
+         ret = H_tfm(H, g)
+
+         if isinstance(ret, torch.Tensor):
+             update = ret
+
+         else: # returns (H, is_inv)
+             H, is_inv = ret
+             if is_inv: update = H @ g
+
+     if eigval_fn is not None:
+         update = _eigh_solve(H, g, eigval_fn, search_negative=False)
+
+     if update is None and use_lstsq: update = _least_squares_solve(H, g)
+     if update is None: update = _cholesky_solve(H, g)
+     if update is None: update = _lu_solve(H, g)
+     if update is None: update = _least_squares_solve(H, g)
+
+     return update
+
+ def _get_H(H: torch.Tensor, eigval_fn):
+     if eigval_fn is not None:
+         try:
+             L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+             L: torch.Tensor = eigval_fn(L)
+             H = Q @ L.diag_embed() @ Q.mH
+             H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
+             return DenseWithInverse(H, H_inv)
+
+         except torch.linalg.LinAlgError:
+             pass
+
+     return Dense(H)
  
  class Newton(Module):
      """Exact newton's method via autograd.
@@ -81,7 +162,6 @@ class Newton(Module):
              how to calculate hessian. Defaults to "autograd".
          vectorize (bool, optional):
              whether to enable vectorized hessian. Defaults to True.
-         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
          H_tfm (Callable | None, optional):
              optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
  
@@ -94,6 +174,7 @@ class Newton(Module):
          eigval_fn (Callable | None, optional):
              optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
              If this is specified, eigendecomposition will be used to invert the hessian.
+         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
  
      # See also
  
@@ -111,10 +192,9 @@ class Newton(Module):
      The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
      Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
  
-     Additionally, if ``eigval_fn`` is specified or ``search_negative`` is ``True``,
-     eigendecomposition of the hessian is computed, ``eigval_fn`` is applied to the eigenvalues,
-     and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues.
-     This is more generally more computationally expensive.
+     Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
+     ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the eigenvectors and transformed eigenvalues.
+     This is generally more computationally expensive, but not by much.
  
      ## Handling non-convexity
  
@@ -167,16 +247,15 @@ class Newton(Module):
      def __init__(
          self,
          damping: float = 0,
-         search_negative: bool = False,
          use_lstsq: bool = False,
          update_freq: int = 1,
          hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
          vectorize: bool = True,
-         inner: Chainable | None = None,
          H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
          eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+         inner: Chainable | None = None,
      ):
-         defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, search_negative=search_negative, update_freq=update_freq)
+         defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
          super().__init__(defaults)
  
          if inner is not None:
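As a concrete illustration of the ``eigval_fn`` option documented in the docstring above (e.g. ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``): the Hessian is eigendecomposed, the transform is applied to the spectrum, and the system is solved with the modified eigenvalues. A hedged sketch in plain torch; the function name and the explicit ``floor`` argument are illustrative, not torchzero's ``_eigh_solve``.

```python
import torch

def eig_modified_newton_direction(H: torch.Tensor, g: torch.Tensor,
                                  eigval_fn=torch.abs, floor: float = 1e-8) -> torch.Tensor:
    """Solve H d = g after transforming the eigenvalues of a symmetric H.

    Sketch of the eigval_fn behaviour described above, not the library's code."""
    L, Q = torch.linalg.eigh(H)          # H = Q diag(L) Q^T for symmetric H
    L = eigval_fn(L).clamp(min=floor)    # e.g. absolute values, bounded away from zero
    return Q @ ((Q.T @ g) / L)           # apply (Q diag(L) Q^T)^{-1} without forming it

# indefinite toy Hessian: a plain solve would follow the negative-curvature direction
H = torch.tensor([[2.0, 0.0], [0.0, -1.0]])
g = torch.tensor([1.0, 1.0])
print(eig_modified_newton_direction(H, g))  # tensor([0.5000, 1.0000])
```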
@@ -184,200 +263,31 @@ class Newton(Module):
  
      @torch.no_grad
      def update(self, var):
-         params = TensorList(var.params)
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         damping = settings['damping']
-         hessian_method = settings['hessian_method']
-         vectorize = settings['vectorize']
-         update_freq = settings['update_freq']
-
          step = self.global_state.get('step', 0)
          self.global_state['step'] = step + 1
  
-         g_list = var.grad
-         H = None
-         if step % update_freq == 0:
-             # ------------------------ calculate grad and hessian ------------------------ #
-             if hessian_method == 'autograd':
-                 with torch.enable_grad():
-                     loss = var.loss = var.loss_approx = closure(False)
-                     g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                     g_list = [t[0] for t in g_list] # remove leading dim from loss
-                     var.grad = g_list
-                     H = flatten_jacobian(H_list)
-
-             elif hessian_method in ('func', 'autograd.functional'):
-                 strat = 'forward-mode' if vectorize else 'reverse-mode'
-                 with torch.enable_grad():
-                     g_list = var.get_grad(retain_graph=True)
-                     H = hessian_mat(partial(closure, backward=False), params,
-                                     method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-             else:
-                 raise ValueError(hessian_method)
-
-             if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
-             self.global_state['H'] = H
+         if step % self.defaults['update_freq'] == 0:
+             loss, g_list, self.global_state['H'] = _get_loss_grad_and_hessian(
+                 var, self.defaults['hessian_method'], self.defaults['vectorize']
+             )
  
      @torch.no_grad
      def apply(self, var):
-         H = self.global_state["H"]
-
          params = var.params
-         settings = self.settings[params[0]]
-         search_negative = settings['search_negative']
-         H_tfm = settings['H_tfm']
-         eigval_fn = settings['eigval_fn']
-         use_lstsq = settings['use_lstsq']
-
-         # -------------------------------- inner step -------------------------------- #
-         update = var.get_update()
-         if 'inner' in self.children:
-             update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-         g = torch.cat([t.ravel() for t in update])
-
-         # ----------------------------------- solve ---------------------------------- #
-         update = None
-         if H_tfm is not None:
-             ret = H_tfm(H, g)
-
-             if isinstance(ret, torch.Tensor):
-                 update = ret
-
-             else: # returns (H, is_inv)
-                 H, is_inv = ret
-                 if is_inv: update = H @ g
-
-         if search_negative or (eigval_fn is not None):
-             update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)
-
-         if update is None and use_lstsq: update = _least_squares_solve(H, g)
-         if update is None: update = _cholesky_solve(H, g)
-         if update is None: update = _lu_solve(H, g)
-         if update is None: update = _least_squares_solve(H, g)
+         update = _newton_step(
+             var=var,
+             H = self.global_state["H"],
+             damping=self.defaults["damping"],
+             inner=self.children.get("inner", None),
+             H_tfm=self.defaults["H_tfm"],
+             eigval_fn=self.defaults["eigval_fn"],
+             use_lstsq=self.defaults["use_lstsq"],
+         )
  
          var.update = vec_to_tensors(update, params)
  
          return var
  
-     def get_H(self,var):
-         H = self.global_state["H"]
-         settings = self.defaults
-         if settings['eigval_fn'] is not None:
-             try:
-                 L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
-                 L = settings['eigval_fn'](L)
-                 H = Q @ L.diag_embed() @ Q.mH
-                 H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
-                 return DenseWithInverse(H, H_inv)
-
-             except torch.linalg.LinAlgError:
-                 pass
-
-         return Dense(H)
-
-
- class InverseFreeNewton(Module):
-     """Inverse-free newton's method
-
-     .. note::
-         In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-     .. note::
-         This module requires the a closure passed to the optimizer step,
-         as it needs to re-evaluate the loss and gradients for calculating the hessian.
-         The closure must accept a ``backward`` argument (refer to documentation).
-
-     .. warning::
-         this uses roughly O(N^2) memory.
-
-     Reference
-         Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
-     """
-     def __init__(
-         self,
-         update_freq: int = 1,
-         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-         vectorize: bool = True,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
-         super().__init__(defaults)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @torch.no_grad
-     def update(self, var):
-         params = TensorList(var.params)
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         hessian_method = settings['hessian_method']
-         vectorize = settings['vectorize']
-         update_freq = settings['update_freq']
-
-         step = self.global_state.get('step', 0)
-         self.global_state['step'] = step + 1
-
-         g_list = var.grad
-         Y = None
-         if step % update_freq == 0:
-             # ------------------------ calculate grad and hessian ------------------------ #
-             if hessian_method == 'autograd':
-                 with torch.enable_grad():
-                     loss = var.loss = var.loss_approx = closure(False)
-                     g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                     g_list = [t[0] for t in g_list] # remove leading dim from loss
-                     var.grad = g_list
-                     H = flatten_jacobian(H_list)
-
-             elif hessian_method in ('func', 'autograd.functional'):
-                 strat = 'forward-mode' if vectorize else 'reverse-mode'
-                 with torch.enable_grad():
-                     g_list = var.get_grad(retain_graph=True)
-                     H = hessian_mat(partial(closure, backward=False), params,
-                                     method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-             else:
-                 raise ValueError(hessian_method)
-
-             self.global_state["H"] = H
-
-             # inverse free part
-             if 'Y' not in self.global_state:
-                 num = H.T
-                 denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
-                 finfo = torch.finfo(H.dtype)
-                 Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
-
-             else:
-                 Y = self.global_state['Y']
-                 I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
-                 I -= H @ Y
-                 Y = self.global_state['Y'] = Y @ I
-
-
-     def apply(self, var):
-         Y = self.global_state["Y"]
-         params = var.params
-
-         # -------------------------------- inner step -------------------------------- #
-         update = var.get_update()
-         if 'inner' in self.children:
-             update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-         g = torch.cat([t.ravel() for t in update])
-
-         # ----------------------------------- solve ---------------------------------- #
-         var.update = vec_to_tensors(Y@g, params)
-
-         return var
+     def get_H(self,var=...):
+         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
  
-     def get_H(self,var):
-         return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
torchzero/modules/second_order/nystrom.py
@@ -193,7 +193,7 @@ class NystromPCG(Module):
          self,
          sketch_size: int,
          maxiter=None,
-         tol=1e-3,
+         tol=1e-8,
          reg: float = 1e-6,
          hvp_method: Literal["forward", "central", "autograd"] = "autograd",
          h=1e-3,