torchzero 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. tests/test_opts.py +4 -10
  2. torchzero/core/__init__.py +4 -1
  3. torchzero/core/chain.py +50 -0
  4. torchzero/core/functional.py +37 -0
  5. torchzero/core/modular.py +237 -0
  6. torchzero/core/module.py +12 -599
  7. torchzero/core/reformulation.py +3 -1
  8. torchzero/core/transform.py +7 -5
  9. torchzero/core/var.py +376 -0
  10. torchzero/modules/__init__.py +0 -1
  11. torchzero/modules/adaptive/adahessian.py +2 -2
  12. torchzero/modules/adaptive/esgd.py +2 -2
  13. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  14. torchzero/modules/adaptive/sophia_h.py +2 -2
  15. torchzero/modules/conjugate_gradient/cg.py +16 -16
  16. torchzero/modules/experimental/__init__.py +1 -0
  17. torchzero/modules/experimental/newtonnewton.py +5 -5
  18. torchzero/modules/experimental/spsa1.py +93 -0
  19. torchzero/modules/functional.py +7 -0
  20. torchzero/modules/grad_approximation/__init__.py +1 -1
  21. torchzero/modules/grad_approximation/forward_gradient.py +2 -5
  22. torchzero/modules/grad_approximation/rfdm.py +27 -110
  23. torchzero/modules/line_search/__init__.py +1 -1
  24. torchzero/modules/line_search/_polyinterp.py +3 -1
  25. torchzero/modules/line_search/adaptive.py +3 -3
  26. torchzero/modules/line_search/backtracking.py +1 -1
  27. torchzero/modules/line_search/interpolation.py +160 -0
  28. torchzero/modules/line_search/line_search.py +11 -20
  29. torchzero/modules/line_search/scipy.py +15 -3
  30. torchzero/modules/line_search/strong_wolfe.py +3 -5
  31. torchzero/modules/misc/misc.py +2 -2
  32. torchzero/modules/misc/multistep.py +13 -13
  33. torchzero/modules/quasi_newton/__init__.py +2 -0
  34. torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  35. torchzero/modules/quasi_newton/sg2.py +292 -0
  36. torchzero/modules/restarts/restars.py +5 -4
  37. torchzero/modules/second_order/__init__.py +6 -3
  38. torchzero/modules/second_order/ifn.py +89 -0
  39. torchzero/modules/second_order/inm.py +105 -0
  40. torchzero/modules/second_order/newton.py +103 -193
  41. torchzero/modules/second_order/newton_cg.py +86 -110
  42. torchzero/modules/second_order/nystrom.py +1 -1
  43. torchzero/modules/second_order/rsn.py +227 -0
  44. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  45. torchzero/modules/trust_region/trust_cg.py +6 -4
  46. torchzero/modules/wrappers/optim_wrapper.py +49 -42
  47. torchzero/modules/zeroth_order/__init__.py +1 -1
  48. torchzero/modules/zeroth_order/cd.py +1 -238
  49. torchzero/utils/derivatives.py +19 -19
  50. torchzero/utils/linalg/linear_operator.py +50 -2
  51. torchzero/utils/optimizer.py +2 -2
  52. torchzero/utils/python_tools.py +1 -0
  53. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
  54. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/RECORD +57 -48
  55. torchzero/modules/higher_order/__init__.py +0 -1
  56. /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
  57. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
  58. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
@@ -13,21 +13,12 @@ from ..trust_region.trust_region import default_radius
13
13
  class NewtonCG(Module):
14
14
  """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
15
15
 
16
- This optimizer implements Newton's method using a matrix-free conjugate
17
- gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
18
- forming the full Hessian matrix, it only requires Hessian-vector products
19
- (HVPs). These can be calculated efficiently using automatic
20
- differentiation or approximated using finite differences.
16
+ Notes:
17
+ * In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
21
18
 
22
- .. note::
23
- In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
19
+ * This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
24
20
 
25
- .. note::
26
- This module requires the a closure passed to the optimizer step,
27
- as it needs to re-evaluate the loss and gradients for calculating HVPs.
28
- The closure must accept a ``backward`` argument (refer to documentation).
29
-
30
- .. warning::
21
+ Warning:
31
22
  CG may fail if the hessian is not positive-definite.
32
23
 
33
24
  Args:
@@ -66,26 +57,24 @@ class NewtonCG(Module):
66
57
  NewtonCG will attempt to apply preconditioning to the output of this module.
67
58
 
68
59
  Examples:
69
- Newton-CG with a backtracking line search:
70
-
71
- .. code-block:: python
72
-
73
- opt = tz.Modular(
74
- model.parameters(),
75
- tz.m.NewtonCG(),
76
- tz.m.Backtracking()
77
- )
78
-
79
- Truncated Newton method (useful for large-scale problems):
80
-
81
- .. code-block:: python
82
-
83
- opt = tz.Modular(
84
- model.parameters(),
85
- tz.m.NewtonCG(maxiter=10, warm_start=True),
86
- tz.m.Backtracking()
87
- )
88
-
60
+ Newton-CG with a backtracking line search:
61
+
62
+ ```python
63
+ opt = tz.Modular(
64
+ model.parameters(),
65
+ tz.m.NewtonCG(),
66
+ tz.m.Backtracking()
67
+ )
68
+ ```
69
+
70
+ Truncated Newton method (useful for large-scale problems):
71
+ ```python
72
+ opt = tz.Modular(
73
+ model.parameters(),
74
+ tz.m.NewtonCG(maxiter=10),
75
+ tz.m.Backtracking()
76
+ )
77
+ ```
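Both NewtonCG examples above assume a closure is passed to ``opt.step``. Since the module re-evaluates the loss and gradients for HVPs, the closure must accept a ``backward`` argument; below is a minimal sketch of such a closure, assuming the usual torchzero convention where ``backward=True`` triggers ``zero_grad`` and ``backward`` (``model``, ``criterion``, ``inputs`` and ``targets`` are placeholders):

```python
def closure(backward=True):
    # re-evaluate the loss; only compute gradients when requested
    loss = criterion(model(inputs), targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```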
89
78
 
90
79
  """
91
80
  def __init__(
@@ -95,7 +84,7 @@ class NewtonCG(Module):
95
84
  reg: float = 1e-8,
96
85
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
97
86
  solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
98
- h: float = 1e-3,
87
+ h: float = 1e-3, # tuned 1e-4 or 1e-3
99
88
  miniter:int = 1,
100
89
  warm_start=False,
101
90
  inner: Chainable | None = None,
@@ -187,96 +176,95 @@ class NewtonCG(Module):
187
176
 
188
177
 
189
178
  class NewtonCGSteihaug(Module):
190
- """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
179
+ """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.
191
180
 
192
- This optimizer implements Newton's method using a matrix-free conjugate
193
- gradient (CG) solver to approximate the search direction. Instead of
194
- forming the full Hessian matrix, it only requires Hessian-vector products
195
- (HVPs). These can be calculated efficiently using automatic
196
- differentiation or approximated using finite differences.
181
+ Notes:
182
+ * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
197
183
 
198
- .. note::
199
- In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
200
-
201
- .. note::
202
- This module requires the a closure passed to the optimizer step,
203
- as it needs to re-evaluate the loss and gradients for calculating HVPs.
204
- The closure must accept a ``backward`` argument (refer to documentation).
205
-
206
- .. warning::
207
- CG may fail if hessian is not positive-definite.
184
+ * This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
208
185
 
209
186
  Args:
210
- maxiter (int | None, optional):
211
- Maximum number of iterations for the conjugate gradient solver.
212
- By default, this is set to the number of dimensions in the
213
- objective function, which is the theoretical upper bound for CG
214
- convergence. Setting this to a smaller value (truncated Newton)
215
- can still generate good search directions. Defaults to None.
216
187
  eta (float, optional):
217
- whenever actual to predicted loss reduction ratio is larger than this, a step is accepted.
218
- nplus (float, optional):
219
- trust region multiplier on successful steps.
220
- nminus (float, optional):
221
- trust region multiplier on unsuccessful steps.
222
- init (float, optional): initial trust region.
188
+ if the ratio of actual to predicted reduction is larger than this, the step is accepted. Defaults to 0.0.
189
+ nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
190
+ nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
191
+ rho_good (float, optional):
192
+ if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by `nplus`.
193
+ rho_bad (float, optional):
194
+ if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by `nminus`.
195
+ init (float, optional): Initial trust region value. Defaults to 1.
196
+ max_attempts (int, optional):
197
+ maximum number of trust radius reductions per step. A zero update vector is returned when
198
+ this limit is exceeded. Defaults to 100.
199
+ max_history (int, optional):
200
+ CG will store this many intermediate solutions, reusing them when trust radius is reduced
201
+ instead of re-running CG. Each solution storage requires 2N memory. Defaults to 100.
202
+ boundary_tol (float | None, optional):
203
+ The trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
204
+ This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.
205
+
206
+ maxiter (int | None, optional):
207
+ maximum number of CG iterations per step. Each iteration requires one backward pass if `hvp_method="forward"`, two otherwise. Defaults to None.
208
+ miniter (int, optional):
209
+ minimal number of CG iterations. This prevents making no progress when the initial guess is already below tolerance. Defaults to 1.
223
210
  tol (float, optional):
224
- Relative tolerance for the conjugate gradient solver to determine
225
- convergence. Defaults to 1e-4.
226
- reg (float, optional):
227
- Regularization parameter (damping) added to the Hessian diagonal.
228
- This helps ensure the system is positive-definite. Defaults to 1e-8.
211
+ terminates CG when the norm of the residual is less than this value. Defaults to 1e-8.
213
+ reg (float, optional): hessian regularization. Defaults to 1e-8.
214
+ solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
215
+ adapt_tol (bool, optional):
216
+ if True, whenever trust radius collapses to smallest representable number,
217
+ the tolerance is multiplied by 0.1. Defaults to True.
218
+ npc_terminate (bool, optional):
219
+ whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.
220
+
229
221
  hvp_method (str, optional):
230
- Determines how Hessian-vector products are evaluated.
222
+ either "forward" to use forward formula which requires one backward pass per Hvp, or "central" to use a more accurate central formula which requires two backward passes. "forward" is usually accurate enough. Defaults to "forward".
223
+ h (float, optional): finite difference step size. Defaults to 1e-3.
231
224
 
232
- - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
233
- This requires creating a graph for the gradient.
234
- - ``"forward"``: Use a forward finite difference formula to
235
- approximate the HVP. This requires one extra gradient evaluation.
236
- - ``"central"``: Use a central finite difference formula for a
237
- more accurate HVP approximation. This requires two extra
238
- gradient evaluations.
239
- Defaults to "autograd".
240
- h (float, optional):
241
- The step size for finite differences if :code:`hvp_method` is
242
- ``"forward"`` or ``"central"``. Defaults to 1e-3.
243
225
  inner (Chainable | None, optional):
244
- NewtonCG will attempt to apply preconditioning to the output of this module.
226
+ applies preconditioning to the output of this module. Defaults to None.
245
227
 
246
- Examples:
247
- Trust-region Newton-CG:
248
-
249
- .. code-block:: python
228
+ ### Examples:
229
+ Trust-region Newton-CG:
250
230
 
251
- opt = tz.Modular(
252
- model.parameters(),
253
- tz.m.NewtonCGSteihaug(),
254
- )
231
+ ```python
232
+ opt = tz.Modular(
233
+ model.parameters(),
234
+ tz.m.NewtonCGSteihaug(),
235
+ )
236
+ ```
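The trust-region behaviour described in the Args (``eta``, ``rho_good``, ``rho_bad``, ``nplus``, ``nminus``, ``boundary_tol``) amounts to the standard accept/expand/shrink rule. A plain-Python paraphrase of those documented semantics, as an illustration rather than the actual torchzero implementation:

```python
def update_trust_region(radius, ratio, step_norm, eta=0.0, nplus=3.5, nminus=0.25,
                        rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    # accept the step when the actual/predicted reduction ratio exceeds eta
    accepted = ratio > eta
    if ratio < rho_bad:
        # poor agreement with the quadratic model: shrink the radius
        radius *= nminus
    elif ratio > rho_good and step_norm >= (1 - boundary_tol) * radius:
        # good agreement and the step lies on the boundary: expand the radius
        radius *= nplus
    return accepted, radius
```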
255
237
 
256
- Reference:
238
+ ### Reference:
257
239
  Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
258
240
  """
259
241
  def __init__(
260
242
  self,
261
- maxiter: int | None = None,
243
+ # trust region settings
262
244
  eta: float= 0.0,
263
245
  nplus: float = 3.5,
264
246
  nminus: float = 0.25,
265
247
  rho_good: float = 0.99,
266
248
  rho_bad: float = 1e-4,
267
249
  init: float = 1,
268
- tol: float = 1e-8,
269
- reg: float = 1e-8,
270
- hvp_method: Literal["forward", "central"] = "forward",
271
- solver: Literal['cg', "minres"] = 'cg',
272
- h: float = 1e-3,
273
250
  max_attempts: int = 100,
274
251
  max_history: int = 100,
275
- boundary_tol: float = 1e-1,
252
+ boundary_tol: float = 1e-6, # tuned
253
+
254
+ # cg settings
255
+ maxiter: int | None = None,
276
256
  miniter: int = 1,
277
- rms_beta: float | None = None,
257
+ tol: float = 1e-8,
258
+ reg: float = 1e-8,
259
+ solver: Literal['cg', "minres"] = 'cg',
278
260
  adapt_tol: bool = True,
279
261
  npc_terminate: bool = False,
262
+
263
+ # hvp settings
264
+ hvp_method: Literal["forward", "central"] = "central",
265
+ h: float = 1e-3, # tuned 1e-4 or 1e-3
266
+
267
+ # inner
280
268
  inner: Chainable | None = None,
281
269
  ):
282
270
  defaults = locals().copy()
@@ -336,19 +324,8 @@ class NewtonCGSteihaug(Module):
336
324
  raise ValueError(hvp_method)
337
325
 
338
326
 
339
- # ------------------------- update RMS preconditioner ------------------------ #
340
- b = var.get_update()
341
- P_mm = None
342
- rms_beta = self.defaults["rms_beta"]
343
- if rms_beta is not None:
344
- exp_avg_sq = self.get_state(params, "exp_avg_sq", init=b, cls=TensorList)
345
- exp_avg_sq.mul_(rms_beta).addcmul(b, b, value=1-rms_beta)
346
- exp_avg_sq_sqrt = exp_avg_sq.sqrt().add_(1e-8)
347
- def _P_mm(x):
348
- return x / exp_avg_sq_sqrt
349
- P_mm = _P_mm
350
-
351
327
  # -------------------------------- inner step -------------------------------- #
328
+ b = var.get_update()
352
329
  if 'inner' in self.children:
353
330
  b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
354
331
  b = as_tensorlist(b)
@@ -392,7 +369,6 @@ class NewtonCGSteihaug(Module):
392
369
  miniter=miniter,
393
370
  npc_terminate=npc_terminate,
394
371
  history_size=max_history,
395
- P_mm=P_mm,
396
372
  )
397
373
 
398
374
  elif solver == 'minres':
@@ -193,7 +193,7 @@ class NystromPCG(Module):
193
193
  self,
194
194
  sketch_size: int,
195
195
  maxiter=None,
196
- tol=1e-3,
196
+ tol=1e-8,
197
197
  reg: float = 1e-6,
198
198
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
199
199
  h=1e-3,
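For context, a usage sketch of ``NystromPCG`` with the signature shown above; the composition mirrors the NewtonCG examples earlier in this diff, and the ``sketch_size`` value is an arbitrary illustration:

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.NystromPCG(sketch_size=100),
    tz.m.Backtracking()
)
```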
@@ -0,0 +1,227 @@
1
+ import math
2
+ from collections import deque
3
+ from collections.abc import Callable
4
+ from typing import Literal
5
+
6
+ import torch
7
+
8
+ from ...core import Chainable, Module, apply_transform
9
+ from ...utils import Distributions, TensorList, vec_to_tensors
10
+ from ...utils.linalg.linear_operator import Sketched
11
+ from .newton import _newton_step
12
+
13
+ def _qr_orthonormalize(A:torch.Tensor):
14
+ m,n = A.shape
15
+ if m < n:
16
+ q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
17
+ return q.T
18
+ else:
19
+ q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
20
+ return q
21
+
22
+ def _orthonormal_sketch(m, n, dtype, device, generator):
23
+ return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
24
+
25
+ def _gaussian_sketch(m, n, dtype, device, generator):
26
+ return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
27
+
28
+ class RSN(Module):
29
+ """Randomized Subspace Newton. Performs a Newton step in a random subspace.
30
+
31
+ Args:
32
+ sketch_size (int):
33
+ size of the random sketch. This many hessian-vector products will need to be evaluated each step.
34
+ sketch_type (str, optional):
35
+ - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
36
+ - "gaussian" - random gaussian (not orthonormal) basis.
37
+ - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt.
38
+ - "mixed" - random orthonormal basis but with three directions set to gradient, slow EMA and fast EMA (default).
39
+ damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
40
+ hvp_method (str, optional):
41
+ How to compute hessian-matrix product:
42
+ - "batched" - uses batched autograd
43
+ - "autograd" - uses unbatched autograd
44
+ - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
45
+ - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
46
+
47
+ Defaults to "batched".
48
+ h (float, optional): finite difference step size. Defaults to 1e-2.
49
+ use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to True.
50
+ update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
51
+ H_tfm (Callable | None, optional):
52
+ optional hessian transform; takes two arguments - `(hessian, gradient)`.
53
+
54
+ must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
55
+ which must be True if transform inverted the hessian and False otherwise.
56
+
57
+ Or it returns a single tensor which is used as the update.
58
+
59
+ Defaults to None.
60
+ eigval_fn (Callable | None, optional):
61
+ optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
62
+ If this is specified, eigendecomposition will be used to invert the hessian.
63
+ seed (int | None, optional): seed for random generator. Defaults to None.
64
+ inner (Chainable | None, optional): preconditions output of this module. Defaults to None.
65
+
66
+ ### Examples
67
+
68
+ RSN with line search
69
+ ```python
70
+ opt = tz.Modular(
71
+ model.parameters(),
72
+ tz.m.RSN(sketch_size=100),
73
+ tz.m.Backtracking()
74
+ )
75
+ ```
76
+
77
+ RSN with trust region
78
+ ```python
79
+ opt = tz.Modular(
80
+ model.parameters(),
81
+ tz.m.LevenbergMarquardt(tz.m.RSN(sketch_size=100)),
82
+ )
83
+ ```
84
+
85
+
86
+ References:
87
+ 1. [Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).](https://arxiv.org/abs/1905.10874)
88
+ 2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ sketch_size: int,
94
+ sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
95
+ damping:float=0,
96
+ hvp_method: Literal["batched", "autograd", "forward", "central"] = "batched",
97
+ h: float = 1e-2,
98
+ use_lstsq: bool = True,
99
+ update_freq: int = 1,
100
+ H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
101
+ eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
102
+ seed: int | None = None,
103
+ inner: Chainable | None = None,
104
+ ):
105
+ defaults = dict(sketch_size=sketch_size, sketch_type=sketch_type,seed=seed,hvp_method=hvp_method, h=h, damping=damping, use_lstsq=use_lstsq, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
106
+ super().__init__(defaults)
107
+
108
+ if inner is not None:
109
+ self.set_child("inner", inner)
110
+
111
+ @torch.no_grad
112
+ def update(self, var):
113
+ step = self.global_state.get('step', 0)
114
+ self.global_state['step'] = step + 1
115
+
116
+ if step % self.defaults['update_freq'] == 0:
117
+
118
+ closure = var.closure
119
+ if closure is None:
120
+ raise RuntimeError("RSN requires closure")
121
+ params = var.params
122
+ generator = self.get_generator(params[0].device, self.defaults["seed"])
123
+
124
+ ndim = sum(p.numel() for p in params)
125
+
126
+ device=params[0].device
127
+ dtype=params[0].dtype
128
+
129
+ # sample sketch matrix S: (ndim, sketch_size)
130
+ sketch_size = min(self.defaults["sketch_size"], ndim)
131
+ sketch_type = self.defaults["sketch_type"]
132
+ hvp_method = self.defaults["hvp_method"]
133
+
134
+ if sketch_type in ('normal', 'gaussian'):
135
+ S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
136
+
137
+ elif sketch_type == 'orthonormal':
138
+ S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
139
+
140
+ elif sketch_type == 'common_directions':
141
+ # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
142
+ g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
143
+ g = torch.cat([t.ravel() for t in g_list])
144
+
145
+ # initialize directions deque
146
+ if "directions" not in self.global_state:
147
+
148
+ g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
149
+ if g_norm < torch.finfo(g.dtype).tiny * 2:
150
+ g = torch.randn_like(g)
151
+ g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
152
+
153
+ self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
154
+ S = self.global_state["directions"][0].unsqueeze(1)
155
+
156
+ # add new steepest descent direction orthonormal to existing columns
157
+ else:
158
+ S = torch.stack(tuple(self.global_state["directions"]), dim=1)
159
+ p = g - S @ (S.T @ g)
160
+ p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
161
+ if p_norm > torch.finfo(p.dtype).tiny * 2:
162
+ p = p / p_norm
163
+ self.global_state["directions"].append(p)
164
+ S = torch.cat([S, p.unsqueeze(1)], dim=1)
165
+
166
+ elif sketch_type == "mixed":
167
+ g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
168
+ g = torch.cat([t.ravel() for t in g_list])
169
+
170
+ if "slow_ema" not in self.global_state:
171
+ self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
172
+ self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
173
+
174
+ slow_ema = self.global_state["slow_ema"]
175
+ fast_ema = self.global_state["fast_ema"]
176
+ slow_ema.lerp_(g, 0.001)
177
+ fast_ema.lerp_(g, 0.1)
178
+
179
+ S = torch.stack([g, slow_ema, fast_ema], dim=1)
180
+ if sketch_size > 3:
181
+ S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
182
+ S = torch.cat([S, S_random], dim=1)
183
+
184
+ S = _qr_orthonormalize(S)
185
+
186
+ else:
187
+ raise ValueError(f'Unknown sketch_type {sketch_type}')
188
+
189
+ # form sketched hessian
190
+ HS, _ = var.hessian_matrix_product(S, at_x0=True, rgrad=None, hvp_method=self.defaults["hvp_method"], normalize=True, retain_graph=False, h=self.defaults["h"])
191
+ H_sketched = S.T @ HS
192
+
193
+ self.global_state["H_sketched"] = H_sketched
194
+ self.global_state["S"] = S
195
+
196
+ def apply(self, var):
197
+ S: torch.Tensor = self.global_state["S"]
198
+ d_proj = _newton_step(
199
+ var=var,
200
+ H=self.global_state["H_sketched"],
201
+ damping=self.defaults["damping"],
202
+ inner=self.children.get("inner", None),
203
+ H_tfm=self.defaults["H_tfm"],
204
+ eigval_fn=self.defaults["eigval_fn"],
205
+ use_lstsq=self.defaults["use_lstsq"],
206
+ g_proj = lambda g: S.T @ g
207
+ )
208
+ d = S @ d_proj
209
+ var.update = vec_to_tensors(d, var.params)
210
+
211
+ return var
212
+
213
+ def get_H(self, var=...):
214
+ eigval_fn = self.defaults["eigval_fn"]
215
+ H_sketched: torch.Tensor = self.global_state["H_sketched"]
216
+ S: torch.Tensor = self.global_state["S"]
217
+
218
+ if eigval_fn is not None:
219
+ try:
220
+ L, Q = torch.linalg.eigh(H_sketched) # pylint:disable=not-callable
221
+ L: torch.Tensor = eigval_fn(L)
222
+ H_sketched = Q @ L.diag_embed() @ Q.mH
223
+
224
+ except torch.linalg.LinAlgError:
225
+ pass
226
+
227
+ return Sketched(S, H_sketched)
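The core of ``RSN.update``/``RSN.apply`` above is the sketched Newton step: project the hessian and gradient into the subspace spanned by ``S``, solve the small system there, then lift the solution back. A self-contained toy version on a random positive-definite quadratic (standalone PyTorch, not torchzero code):

```python
import torch

n, k = 50, 5
A = torch.randn(n, n)
H = A @ A.T + torch.eye(n)                        # SPD stand-in for the hessian
g = torch.randn(n)                                # stand-in for the gradient

S, _ = torch.linalg.qr(torch.randn(n, k))         # orthonormal sketch S: (n, k)
H_sketched = S.T @ (H @ S)                        # k x k sketched hessian
d_proj = torch.linalg.solve(H_sketched, S.T @ g)  # Newton step in the subspace
d = S @ d_proj                                    # lift back to R^n: the RSN direction
```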
@@ -14,13 +14,13 @@ class LevenbergMarquardt(TrustRegionBase):
14
14
  hess_module (Module | None, optional):
15
15
  A module that maintains a hessian approximation (not hessian inverse!).
16
16
  This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
17
- When using quasi-newton methods, set `inverse=False` when constructing them.
17
+ When using quasi-newton methods, set ``inverse=False`` when constructing them.
18
18
  y (float, optional):
19
19
  when ``y=0``, identity matrix is added to hessian, when ``y=1``, diagonal of the hessian approximation
20
20
  is added. Values between interpolate. This should only be used with Gauss-Newton. Defaults to 0.
21
21
  eta (float, optional):
22
22
  if the ratio of actual to predicted reduction is larger than this, the step is accepted.
23
- When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
23
+ When ``hess_module`` is ``Newton`` or ``GaussNewton``, this can be set to 0. Defaults to 0.15.
24
24
  nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
25
25
  nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
26
26
  rho_good (float, optional):
@@ -60,17 +60,19 @@ class TrustCG(TrustRegionBase):
60
60
  nminus: float = 0.25,
61
61
  rho_good: float = 0.99,
62
62
  rho_bad: float = 1e-4,
63
- boundary_tol: float | None = 1e-1,
63
+ boundary_tol: float | None = 1e-6, # tuned
64
64
  init: float = 1,
65
65
  max_attempts: int = 10,
66
66
  radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
67
67
  reg: float = 0,
68
- cg_tol: float = 1e-4,
68
+ maxiter: int | None = None,
69
+ miniter: int = 1,
70
+ cg_tol: float = 1e-8,
69
71
  prefer_exact: bool = True,
70
72
  update_freq: int = 1,
71
73
  inner: Chainable | None = None,
72
74
  ):
73
- defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol)
75
+ defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol, maxiter=maxiter, miniter=miniter)
74
76
  super().__init__(
75
77
  defaults=defaults,
76
78
  hess_module=hess_module,
@@ -93,5 +95,5 @@ class TrustCG(TrustRegionBase):
93
95
  if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
94
96
  return H.solve_bounded(g, radius)
95
97
 
96
- x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], tol=settings["cg_tol"])
98
+ x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], maxiter=settings["maxiter"], miniter=settings["miniter"], tol=settings["cg_tol"])
97
99
  return x
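A configuration sketch using the CG controls added to ``TrustCG`` above. ``hess_module`` is forwarded to ``TrustRegionBase`` in ``super().__init__`` as shown, and pairing it with ``tz.m.Newton()`` follows the ``hess_module`` description in the LevenbergMarquardt docstring, so treat this as illustrative rather than canonical:

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.TrustCG(hess_module=tz.m.Newton(), maxiter=50, cg_tol=1e-8),
)
```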