torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0

torchzero/modules/second_order/newton.py
@@ -1,22 +1,22 @@
  from collections.abc import Callable
- from typing import Literal
+ from typing import Any

  import torch

- from ...core import Chainable, Transform, Objective, HessianMethod, Module
- from ...utils import vec_to_tensors
- from ...linalg.linear_operator import Dense, DenseWithInverse
+ from ...core import Chainable, Transform, Objective, HessianMethod
+ from ...utils import vec_to_tensors_
+ from ...linalg.linear_operator import Dense, DenseWithInverse, Eigendecomposition
+ from ...linalg import torch_linalg

-
- def _lu_solve(H: torch.Tensor, g: torch.Tensor):
+ def _try_lu_solve(H: torch.Tensor, g: torch.Tensor):
  try:
- x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+ x, info = torch_linalg.solve_ex(H, g, retry_float64=True)
  if info == 0: return x
  return None
  except RuntimeError:
  return None

- def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
+ def _try_cholesky_solve(H: torch.Tensor, g: torch.Tensor):
  L, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
  if info == 0:
  return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
@@ -25,77 +25,91 @@ def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
  def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
  return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable

- def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
- try:
- L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
- if tfm is not None: L = tfm(L)
- if search_negative and L[0] < 0:
- neg_mask = L < 0
- Q_neg = Q[:, neg_mask] * L[neg_mask]
- return (Q_neg * (g @ Q_neg).sign()).mean(1)
-
- return Q @ ((Q.mH @ g) / L)
-
- except torch.linalg.LinAlgError:
- return None
-
- def _newton_step(objective: Objective, H: torch.Tensor, damping:float, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None, no_inner: Module | None = None) -> torch.Tensor:
- """INNER SHOULD BE NONE IN MOST CASES! Because Transform already has inner.
- Returns the update tensor, then do vec_to_tensor(update, params)"""
- # -------------------------------- inner step -------------------------------- #
- if no_inner is not None:
- objective = no_inner.step(objective)
-
- update = objective.get_updates()
-
- g = torch.cat([t.ravel() for t in update])
- if g_proj is not None: g = g_proj(g)
-
- # ----------------------------------- solve ---------------------------------- #
- update = None
-
+ def _newton_update_state_(
+ state: dict,
+ H: torch.Tensor,
+ damping: float,
+ eigval_fn: Callable | None,
+ precompute_inverse: bool,
+ use_lstsq: bool,
+ ):
+ """used in most hessian-based modules"""
+ # add damping
  if damping != 0:
- H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
-
- if H_tfm is not None:
- ret = H_tfm(H, g)
-
- if isinstance(ret, torch.Tensor):
- update = ret
+ reg = torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(damping)
+ H += reg

- else: # returns (H, is_inv)
- H, is_inv = ret
- if is_inv: update = H @ g
-
- if eigval_fn is not None:
- update = _eigh_solve(H, g, eigval_fn, search_negative=False)
-
- if update is None and use_lstsq: update = _least_squares_solve(H, g)
- if update is None: update = _cholesky_solve(H, g)
- if update is None: update = _lu_solve(H, g)
- if update is None: update = _least_squares_solve(H, g)
-
- return update
-
- def _get_H(H: torch.Tensor, eigval_fn):
+ # if eigval_fn is given, we don't need H or H_inv, we store factors
  if eigval_fn is not None:
- try:
- L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
- L: torch.Tensor = eigval_fn(L)
- H = Q @ L.diag_embed() @ Q.mH
- H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
- return DenseWithInverse(H, H_inv)
-
- except torch.linalg.LinAlgError:
- pass
+ L, Q = torch_linalg.eigh(H, retry_float64=True)
+ L = eigval_fn(L)
+ state["L"] = L
+ state["Q"] = Q
+ return
+
+ # pre-compute inverse if requested
+ # store H to as it is needed for trust regions
+ state["H"] = H
+ if precompute_inverse:
+ if use_lstsq:
+ H_inv = torch.linalg.pinv(H) # pylint:disable=not-callable
+ else:
+ H_inv, _ = torch_linalg.inv_ex(H)
+ state["H_inv"] = H_inv
+
+
+ def _newton_solve(
+ b: torch.Tensor,
+ state: dict[str, torch.Tensor | Any],
+ use_lstsq: bool = False,
+ ):
+ """
+ used in most hessian-based modules. state is from ``_newton_update_state_``, in it:

- return Dense(H)
+ H (torch.Tensor): hessian
+ H_inv (torch.Tensor | None): hessian inverse
+ L (torch.Tensor | None): eigenvalues (transformed)
+ Q (torch.Tensor | None): eigenvectors
+ """
+ # use eig if provided
+ if "L" in state:
+ Q = state["Q"]; L = state["L"]
+ assert Q is not None
+ return Q @ ((Q.mH @ b) / L)
+
+ # use inverse if cached
+ if "H_inv" in state:
+ return state["H_inv"] @ b
+
+ # use hessian
+ H = state["H"]
+ if use_lstsq: return _least_squares_solve(H, b)
+
+ dir = None
+ if dir is None: dir = _try_cholesky_solve(H, b)
+ if dir is None: dir = _try_lu_solve(H, b)
+ if dir is None: dir = _least_squares_solve(H, b)
+ return dir
+
+ def _newton_get_H(state: dict[str, torch.Tensor | Any]):
+ """used in most hessian-based modules. state is from ``_newton_update_state_``"""
+ if "H_inv" in state:
+ return DenseWithInverse(state["H"], state["H_inv"])
+
+ if "L" in state:
+ # Eigendecomposition has sligthly different solve_plus_diag
+ # I am pretty sure it should be very close and it uses no solves
+ # best way to test is to try cubic regularization with this
+ return Eigendecomposition(state["L"], state["Q"], use_nystrom=False)
+
+ return Dense(state["H"])

  class Newton(Transform):
- """Exact newton's method via autograd.
+ """Exact Newton's method via autograd.

  Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
  The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.
+
  ``g`` can be output of another module, if it is specifed in ``inner`` argument.

  Note:
@@ -107,27 +121,19 @@ class Newton(Transform):
  The closure must accept a ``backward`` argument (refer to documentation).

  Args:
- damping (float, optional): tikhonov regularizer value. Set this to 0 when using trust region. Defaults to 0.
- search_negative (bool, Optional):
- if True, whenever a negative eigenvalue is detected,
- search direction is proposed along weighted sum of eigenvectors corresponding to negative eigenvalues.
- use_lstsq (bool, Optional):
- if True, least squares will be used to solve the linear system, this may generate reasonable directions
- when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
- If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
- argument will be ignored.
- H_tfm (Callable | None, optional):
- optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
-
- must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
- which must be True if transform inverted the hessian and False otherwise.
-
- Or it returns a single tensor which is used as the update.
-
- Defaults to None.
+ damping (float, optional): tikhonov regularizer value. Defaults to 0.
  eigval_fn (Callable | None, optional):
- optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
+ function to apply to eigenvalues, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
  If this is specified, eigendecomposition will be used to invert the hessian.
+ update_freq (int, optional):
+ updates hessian every ``update_freq`` steps.
+ precompute_inverse (bool, optional):
+ if ``True``, whenever hessian is computed, also computes the inverse. This is more efficient
+ when ``update_freq`` is large. If ``None``, this is ``True`` if ``update_freq >= 10``.
+ use_lstsq (bool, Optional):
+ if True, least squares will be used to solve the linear system, this can prevent it from exploding
+ when hessian is indefinite. If False, tries cholesky, if it fails tries LU, and then least squares.
+ If ``eigval_fn`` is specified, eigendecomposition is always used and this argument is ignored.
  hessian_method (str):
  Determines how hessian is computed.

@@ -139,17 +145,19 @@ class Newton(Transform):
  - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
  - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
  - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
+ - ``"thoad"`` - uses ``thoad`` library, can be significantly faster than pytorch but limited operator coverage.

  Defaults to ``"batched_autograd"``.
  h (float, optional):
- finite difference step size for "fd_forward" and "fd_central".
+ finite difference step size if hessian is compute via finite-difference.
  inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.

  # See also

- * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products,
+ * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products.
  useful for large scale problems as it doesn't form the full hessian.
  * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
+ * ``tz.m.ImprovedNewton``: Newton with additional rank one correction to the hessian, can be faster than Newton.
  * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
  * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.

@@ -158,57 +166,48 @@ class Newton(Transform):
  ## Implementation details

  ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
- The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
- Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
+ The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares. Least squares can be forced by setting ``use_lstsq=True``.

  Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
- ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive,
- but not by much
+ ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive but not by much.

  ## Handling non-convexity

  Standard Newton's method does not handle non-convexity well without some modifications.
  This is because it jumps to the stationary point, which may be the maxima of the quadratic approximation.

- The first modification to handle non-convexity is to modify the eignevalues to be positive,
+ A modification to handle non-convexity is to modify the eignevalues to be positive,
  for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.

- Second modification is ``search_negative=True``, which will search along a negative curvature direction if one is detected.
- This also requires an eigendecomposition.
-
- The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
- but that may be significantly less efficient.
-
  # Examples:

  Newton's method with backtracking line search

  ```py
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.Newton(),
  tz.m.Backtracking()
  )
  ```

- Newton preconditioning applied to momentum
+ Newton's method for non-convex optimization.

  ```py
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
- tz.m.Newton(inner=tz.m.EMA(0.9)),
- tz.m.LR(0.1)
+ tz.m.Newton(eigval_fn = lambda L: L.abs().clip(min=1e-4)),
+ tz.m.Backtracking()
  )
  ```

- Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient,
- but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+ Newton preconditioning applied to momentum

  ```py
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
- tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
- tz.m.Backtracking()
+ tz.m.Newton(inner=tz.m.EMA(0.9)),
+ tz.m.LR(0.1)
  )
  ```

@@ -216,10 +215,10 @@
  def __init__(
  self,
  damping: float = 0,
- use_lstsq: bool = False,
- update_freq: int = 1,
- H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
  eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+ update_freq: int = 1,
+ precompute_inverse: bool | None = None,
+ use_lstsq: bool = False,
  hessian_method: HessianMethod = "batched_autograd",
  h: float = 1e-3,
  inner: Chainable | None = None,
@@ -232,29 +231,32 @@ class Newton(Transform):
  def update_states(self, objective, states, settings):
  fs = settings[0]

- _, _, self.global_state['H'] = objective.hessian(
- hessian_method=fs['hessian_method'],
- h=fs['h'],
- at_x0=True
+ precompute_inverse = fs["precompute_inverse"]
+ if precompute_inverse is None:
+ precompute_inverse = fs["__update_freq"] >= 10
+
+ __, _, H = objective.hessian(hessian_method=fs["hessian_method"], h=fs["h"], at_x0=True)
+
+ _newton_update_state_(
+ state = self.global_state,
+ H=H,
+ damping = fs["damping"],
+ eigval_fn = fs["eigval_fn"],
+ precompute_inverse = precompute_inverse,
+ use_lstsq = fs["use_lstsq"]
  )

  @torch.no_grad
  def apply_states(self, objective, states, settings):
- params = objective.params
+ updates = objective.get_updates()
  fs = settings[0]

- update = _newton_step(
- objective=objective,
- H = self.global_state["H"],
- damping = fs["damping"],
- H_tfm = fs["H_tfm"],
- eigval_fn = fs["eigval_fn"],
- use_lstsq = fs["use_lstsq"],
- )
+ b = torch.cat([t.ravel() for t in updates])
+ sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])

- objective.updates = vec_to_tensors(update, params)
+ vec_to_tensors_(sol, updates)
  return objective

  def get_H(self,objective=...):
- return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
+ return _newton_get_H(self.global_state)

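The newton.py hunks above replace the old ``_newton_step``/``_get_H`` helpers with ``_newton_update_state_``, ``_newton_solve`` and ``_newton_get_H``, and ``Newton`` gains ``update_freq`` and ``precompute_inverse`` arguments. A minimal sketch of how the new arguments could be used, following the closure convention described in the docstring (a ``backward`` flag); the toy model and training loop are illustrative and not taken from the package:

```py
import torch
import torchzero as tz

# hypothetical toy problem, only for illustration
model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

opt = tz.Optimizer(
    model.parameters(),
    # recompute the hessian every 10 steps and cache its inverse; with
    # update_freq >= 10 the new precompute_inverse=None default also resolves to True
    tz.m.Newton(update_freq=10, precompute_inverse=True),
    tz.m.Backtracking(),
)

# closure accepts a `backward` flag, as hessian-based modules require
def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
```
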

torchzero/modules/second_order/newton_cg.py
@@ -57,7 +57,7 @@ class NewtonCG(Transform):
  Newton-CG with a backtracking line search:

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.NewtonCG(),
  tz.m.Backtracking()
@@ -66,7 +66,7 @@ class NewtonCG(Transform):

  Truncated Newton method (useful for large-scale problems):
  ```
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.NewtonCG(maxiter=10),
  tz.m.Backtracking()
@@ -198,7 +198,7 @@ class NewtonCGSteihaug(Transform):
  Trust-region Newton-CG:

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.NewtonCGSteihaug(),
  )

torchzero/modules/second_order/nystrom.py
@@ -1,10 +1,11 @@
+ import warnings
  from typing import Literal

  import torch

  from ...core import Chainable, Transform, HVPMethod
  from ...utils import TensorList, vec_to_tensors
- from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg
+ from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod
  from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity

  class NystromSketchAndSolve(Transform):
@@ -19,7 +20,18 @@ class NystromSketchAndSolve(Transform):

  Args:
  rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
- reg (float, optional): regularization parameter. Defaults to 1e-3.
+ reg (float | None, optional):
+ scale of identity matrix added to hessian. Note that if this is specified, nystrom sketch-and-solve
+ is used to compute ``(Q diag(L) Q.T + reg*I)x = b``. It is very unstable when ``reg`` is small,
+ i.e. smaller than 1e-4. If this is None,``(Q diag(L) Q.T)x = b`` is computed by simply taking
+ reciprocal of eigenvalues. Defaults to 1e-3.
+ eigv_tol (float, optional):
+ all eigenvalues smaller than largest eigenvalue times ``eigv_tol`` are removed. Defaults to None.
+ truncate (int | None, optional):
+ keeps top ``truncate`` eigenvalues. Defaults to None.
+ damping (float, optional): scalar added to eigenvalues. Defaults to 0.
+ rdamping (float, optional): scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.
+ update_freq (int, optional): frequency of updating preconditioner. Defaults to 1.
  hvp_method (str, optional):
  Determines how Hessian-vector products are computed.

@@ -40,7 +52,7 @@ class NystromSketchAndSolve(Transform):
  NystromSketchAndSolve with backtracking line search

  ```py
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.NystromSketchAndSolve(100),
  tz.m.Backtracking()
@@ -50,7 +62,7 @@ class NystromSketchAndSolve(Transform):
  Trust region NystromSketchAndSolve

  ```py
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
  )
@@ -64,10 +76,15 @@ class NystromSketchAndSolve(Transform):
  def __init__(
  self,
  rank: int,
- reg: float = 1e-3,
+ reg: float | None = 1e-2,
+ eigv_tol: float = 0,
+ truncate: int | None = None,
+ damping: float = 0,
+ rdamping: float = 0,
+ update_freq: int = 1,
+ orthogonalize_method: OrthogonalizeMethod = 'qr',
  hvp_method: HVPMethod = "batched_autograd",
  h: float = 1e-3,
- update_freq: int = 1,
  inner: Chainable | None = None,
  seed: int | None = None,
  ):
@@ -92,25 +109,53 @@ class NystromSketchAndSolve(Transform):

  generator = self.get_generator(params[0].device, seed=fs['seed'])
  try:
- L, Q = nystrom_approximation(A_mv=H_mv, A_mm=H_mm, ndim=ndim, rank=fs['rank'],
- dtype=dtype, device=device, generator=generator)
+ # compute the approximation
+ L, Q = nystrom_approximation(
+ A_mv=H_mv,
+ A_mm=H_mm,
+ ndim=ndim,
+ rank=min(fs["rank"], ndim),
+ eigv_tol=fs["eigv_tol"],
+ orthogonalize_method=fs["orthogonalize_method"],
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ )
+
+ # regularize
+ L, Q = regularize_eigh(
+ L=L,
+ Q=Q,
+ truncate=fs["truncate"],
+ tol=fs["eigv_tol"],
+ damping=fs["damping"],
+ rdamping=fs["rdamping"],
+ )
+
+ # store
+ if L is not None:
+ self.global_state["L"] = L
+ self.global_state["Q"] = Q

- self.global_state["L"] = L
- self.global_state["Q"] = Q
- except torch.linalg.LinAlgError:
- pass
+ except torch.linalg.LinAlgError as e:
+ warnings.warn(f"Nystrom approximation failed with: {e}")

  def apply_states(self, objective, states, settings):
- fs = settings[0]
- b = objective.get_updates()
-
- # ----------------------------------- solve ---------------------------------- #
  if "L" not in self.global_state:
  return objective

+ fs = settings[0]
+ updates = objective.get_updates()
+ b=torch.cat([t.ravel() for t in updates])
+
+ # ----------------------------------- solve ---------------------------------- #
  L = self.global_state["L"]
  Q = self.global_state["Q"]
- x = nystrom_sketch_and_solve(L=L, Q=Q, b=torch.cat([t.ravel() for t in b]), reg=fs["reg"])
+
+ if fs["reg"] is None:
+ x = Q @ ((Q.mH @ b) / L)
+ else:
+ x = nystrom_sketch_and_solve(L=L, Q=Q, b=b, reg=fs["reg"])

  # -------------------------------- set update -------------------------------- #
  objective.updates = vec_to_tensors(x, reference=objective.params)
@@ -127,8 +172,6 @@ class NystromSketchAndSolve(Transform):

  class NystromPCG(Transform):
  """Newton's method with a Nyström-preconditioned conjugate gradient solver.
- This tends to outperform NewtonCG but requires tuning sketch size.
- An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.

  Notes:
  - This module requires the a closure passed to the optimizer step,
@@ -138,7 +181,7 @@ class NystromPCG(Transform):
  - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

  Args:
- sketch_size (int):
+ rank (int):
  size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
  running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
  conjugate gradient.
@@ -169,7 +212,7 @@ class NystromPCG(Transform):
  NystromPCG with backtracking line search

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.NystromPCG(10),
  tz.m.Backtracking()
@@ -187,6 +230,8 @@ class NystromPCG(Transform):
  tol=1e-8,
  reg: float = 1e-6,
  update_freq: int = 1, # here update_freq is within update_states
+ eigv_tol: float = 0,
+ orthogonalize_method: OrthogonalizeMethod = 'qr',
  hvp_method: HVPMethod = "batched_autograd",
  h=1e-3,
  inner: Chainable | None = None,
@@ -202,31 +247,36 @@ class NystromPCG(Transform):

  # ---------------------- Hessian vector product function --------------------- #
  # this should run on every update_states
- hvp_method = fs['hvp_method']
- h = fs['h']
- _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+ _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=fs['hvp_method'], h=fs['h'], at_x0=True)
  objective.temp = H_mv

  # --------------------------- update preconditioner -------------------------- #
  step = self.increment_counter("step", 0)
- update_freq = self.defaults["update_freq"]
-
- if step % update_freq == 0:
+ if step % fs["update_freq"] == 0:

- rank = fs['rank']
  ndim = sum(t.numel() for t in objective.params)
  device = objective.params[0].device
  dtype = objective.params[0].dtype
  generator = self.get_generator(device, seed=fs['seed'])

  try:
- L, Q = nystrom_approximation(A_mv=None, A_mm=H_mm, ndim=ndim, rank=rank,
- dtype=dtype, device=device, generator=generator)
+ L, Q = nystrom_approximation(
+ A_mv=None,
+ A_mm=H_mm,
+ ndim=ndim,
+ rank=min(fs["rank"], ndim),
+ eigv_tol=fs["eigv_tol"],
+ orthogonalize_method=fs["orthogonalize_method"],
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ )

  self.global_state["L"] = L
  self.global_state["Q"] = Q
- except torch.linalg.LinAlgError:
- pass
+
+ except torch.linalg.LinAlgError as e:
+ warnings.warn(f"Nystrom approximation failed with: {e}")

  @torch.no_grad
  def apply_states(self, objective, states, settings):
@@ -243,6 +293,7 @@ class NystromPCG(Transform):

  L = self.global_state["L"]
  Q = self.global_state["Q"]
+
  x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
  reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])
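
In the nystrom.py hunks, ``NystromSketchAndSolve`` gains a ``reg=None`` path that solves against the factored approximation by taking the reciprocal of the eigenvalues (``x = Q @ ((Q.mH @ b) / L)``) instead of calling ``nystrom_sketch_and_solve``. A self-contained sketch of that identity in plain ``torch``, using an exact full-rank eigendecomposition so the solve reproduces ``H⁻¹b``; the names below are local variables, not torchzero API:

```py
import torch

torch.manual_seed(0)
A = torch.randn(50, 50)
H = A @ A.T + 50 * torch.eye(50)   # symmetric positive definite test matrix
b = torch.randn(50)

# eigendecomposition H = Q diag(L) Q^T
L, Q = torch.linalg.eigh(H)

# the reg=None solve from the diff: apply reciprocal eigenvalues in the eigenbasis
x = Q @ ((Q.mH @ b) / L)

# with a full-rank factorization this matches the direct solve
print(torch.allclose(x, torch.linalg.solve(H, b), atol=1e-4))
```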