torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,31 +1,115 @@
- from collections.abc import Callable
- from typing import Literal, overload
  import warnings
+ import math
+ from typing import Literal, cast
+ from operator import itemgetter

  import torch

- from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import TensorList, as_tensorlist, tofloat
  from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
- from ...core import Chainable, apply_transform, Module
- from ...utils.linalg.solve import cg
+ from ...utils.linalg.solve import cg, minres, find_within_trust_radius
+ from ..trust_region.trust_region import default_radius


  class NewtonCG(Module):
+ """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
+
+ This optimizer implements Newton's method using a matrix-free conjugate
+ gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
+ forming the full Hessian matrix, it only requires Hessian-vector products
+ (HVPs). These can be calculated efficiently using automatic
+ differentiation or approximated using finite differences.
+
+ .. note::
+ In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ .. note::
+ This module requires the a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. warning::
+ CG may fail if hessian is not positive-definite.
+
+ Args:
+ maxiter (int | None, optional):
+ Maximum number of iterations for the conjugate gradient solver.
+ By default, this is set to the number of dimensions in the
+ objective function, which is the theoretical upper bound for CG
+ convergence. Setting this to a smaller value (truncated Newton)
+ can still generate good search directions. Defaults to None.
+ tol (float, optional):
+ Relative tolerance for the conjugate gradient solver to determine
+ convergence. Defaults to 1e-4.
+ reg (float, optional):
+ Regularization parameter (damping) added to the Hessian diagonal.
+ This helps ensure the system is positive-definite. Defaults to 1e-8.
+ hvp_method (str, optional):
+ Determines how Hessian-vector products are evaluated.
+
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+ This requires creating a graph for the gradient.
+ - ``"forward"``: Use a forward finite difference formula to
+ approximate the HVP. This requires one extra gradient evaluation.
+ - ``"central"``: Use a central finite difference formula for a
+ more accurate HVP approximation. This requires two extra
+ gradient evaluations.
+ Defaults to "autograd".
+ h (float, optional):
+ The step size for finite differences if :code:`hvp_method` is
+ ``"forward"`` or ``"central"``. Defaults to 1e-3.
+ warm_start (bool, optional):
+ If ``True``, the conjugate gradient solver is initialized with the
+ solution from the previous optimization step. This can accelerate
+ convergence, especially in truncated Newton methods.
+ Defaults to False.
+ inner (Chainable | None, optional):
+ NewtonCG will attempt to apply preconditioning to the output of this module.
+
+ Examples:
+ Newton-CG with a backtracking line search:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCG(),
+ tz.m.Backtracking()
+ )
+
+ Truncated Newton method (useful for large-scale problems):
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCG(maxiter=10, warm_start=True),
+ tz.m.Backtracking()
+ )
+
+
+ """
  def __init__(
  self,
- maxiter=None,
- tol=1e-4,
+ maxiter: int | None = None,
+ tol: float = 1e-8,
  reg: float = 1e-8,
- hvp_method: Literal["forward", "central", "autograd"] = "forward",
- h=1e-3,
+ hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+ solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
+ h: float = 1e-3,
+ miniter:int = 1,
  warm_start=False,
  inner: Chainable | None = None,
  ):
- defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, warm_start=warm_start)
+ defaults = locals().copy()
+ del defaults['self'], defaults['inner']
  super().__init__(defaults,)

  if inner is not None:
  self.set_child('inner', inner)

+ self._num_hvps = 0
+ self._num_hvps_last_step = 0
+
  @torch.no_grad
  def step(self, var):
  params = TensorList(var.params)
@@ -37,14 +121,17 @@ class NewtonCG(Module):
  reg = settings['reg']
  maxiter = settings['maxiter']
  hvp_method = settings['hvp_method']
+ solver = settings['solver'].lower().strip()
  h = settings['h']
  warm_start = settings['warm_start']

+ self._num_hvps_last_step = 0
  # ---------------------- Hessian vector product function --------------------- #
  if hvp_method == 'autograd':
  grad = var.get_grad(create_graph=True)

  def H_mm(x):
+ self._num_hvps_last_step += 1
  with torch.enable_grad():
  return TensorList(hvp(params, grad, x, retain_graph=True))

@@ -55,10 +142,12 @@

  if hvp_method == 'forward':
  def H_mm(x):
+ self._num_hvps_last_step += 1
  return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])

  elif hvp_method == 'central':
  def H_mm(x):
+ self._num_hvps_last_step += 1
  return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])

  else:
@@ -68,18 +157,279 @@
  # -------------------------------- inner step -------------------------------- #
  b = var.get_update()
  if 'inner' in self.children:
- b = as_tensorlist(apply_transform(self.children['inner'], b, params=params, grads=grad, var=var))
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+ b = as_tensorlist(b)

  # ---------------------------------- run cg ---------------------------------- #
  x0 = None
  if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway

- x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
+ if solver == 'cg':
+ d, _ = cg(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, miniter=self.defaults["miniter"],reg=reg)
+
+ elif solver == 'minres':
+ d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
+
+ elif solver == 'minres_npc':
+ d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+
+ else:
+ raise ValueError(f"Unknown solver {solver}")
+
  if warm_start:
  assert x0 is not None
- x0.copy_(x)
+ x0.copy_(d)

- var.update = x
+ var.update = d
+
+ self._num_hvps += self._num_hvps_last_step
  return var


+ class NewtonCGSteihaug(Module):
+ """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
+
+ This optimizer implements Newton's method using a matrix-free conjugate
+ gradient (CG) solver to approximate the search direction. Instead of
+ forming the full Hessian matrix, it only requires Hessian-vector products
+ (HVPs). These can be calculated efficiently using automatic
+ differentiation or approximated using finite differences.
+
+ .. note::
+ In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ .. note::
+ This module requires the a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. warning::
+ CG may fail if hessian is not positive-definite.
+
+ Args:
+ maxiter (int | None, optional):
+ Maximum number of iterations for the conjugate gradient solver.
+ By default, this is set to the number of dimensions in the
+ objective function, which is the theoretical upper bound for CG
+ convergence. Setting this to a smaller value (truncated Newton)
+ can still generate good search directions. Defaults to None.
+ eta (float, optional):
+ whenever actual to predicted loss reduction ratio is larger than this, a step is accepted.
+ nplus (float, optional):
+ trust region multiplier on successful steps.
+ nminus (float, optional):
+ trust region multiplier on unsuccessful steps.
+ init (float, optional): initial trust region.
+ tol (float, optional):
+ Relative tolerance for the conjugate gradient solver to determine
+ convergence. Defaults to 1e-4.
+ reg (float, optional):
+ Regularization parameter (damping) added to the Hessian diagonal.
+ This helps ensure the system is positive-definite. Defaults to 1e-8.
+ hvp_method (str, optional):
+ Determines how Hessian-vector products are evaluated.
+
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+ This requires creating a graph for the gradient.
+ - ``"forward"``: Use a forward finite difference formula to
+ approximate the HVP. This requires one extra gradient evaluation.
+ - ``"central"``: Use a central finite difference formula for a
+ more accurate HVP approximation. This requires two extra
+ gradient evaluations.
+ Defaults to "autograd".
+ h (float, optional):
+ The step size for finite differences if :code:`hvp_method` is
+ ``"forward"`` or ``"central"``. Defaults to 1e-3.
+ inner (Chainable | None, optional):
+ NewtonCG will attempt to apply preconditioning to the output of this module.
+
+ Examples:
+ Trust-region Newton-CG:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCGSteihaug(),
+ )
+
+ Reference:
+ Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
+ """
+ def __init__(
+ self,
+ maxiter: int | None = None,
+ eta: float= 0.0,
+ nplus: float = 3.5,
+ nminus: float = 0.25,
+ rho_good: float = 0.99,
+ rho_bad: float = 1e-4,
+ init: float = 1,
+ tol: float = 1e-8,
+ reg: float = 1e-8,
+ hvp_method: Literal["forward", "central"] = "forward",
+ solver: Literal['cg', "minres"] = 'cg',
+ h: float = 1e-3,
+ max_attempts: int = 100,
+ max_history: int = 100,
+ boundary_tol: float = 1e-1,
+ miniter: int = 1,
+ rms_beta: float | None = None,
+ adapt_tol: bool = True,
+ npc_terminate: bool = False,
+ inner: Chainable | None = None,
+ ):
+ defaults = locals().copy()
+ del defaults['self'], defaults['inner']
+ super().__init__(defaults,)
+
+ if inner is not None:
+ self.set_child('inner', inner)
+
+ self._num_hvps = 0
+ self._num_hvps_last_step = 0
+
+ @torch.no_grad
+ def step(self, var):
+ params = TensorList(var.params)
+ closure = var.closure
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+ tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
+ solver = self.defaults['solver'].lower().strip()
+
+ (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
+ eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
+ miniter, max_history, adapt_tol) = itemgetter(
+ "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
+ "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
+ "miniter", "max_history", "adapt_tol",
+ )(self.defaults)
+
+ self._num_hvps_last_step = 0
+
+ # ---------------------- Hessian vector product function --------------------- #
+ if hvp_method == 'autograd':
+ grad = var.get_grad(create_graph=True)
+
+ def H_mm(x):
+ self._num_hvps_last_step += 1
+ with torch.enable_grad():
+ return TensorList(hvp(params, grad, x, retain_graph=True))
+
+ else:
+
+ with torch.enable_grad():
+ grad = var.get_grad()
+
+ if hvp_method == 'forward':
+ def H_mm(x):
+ self._num_hvps_last_step += 1
+ return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+
+ elif hvp_method == 'central':
+ def H_mm(x):
+ self._num_hvps_last_step += 1
+ return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+
+ else:
+ raise ValueError(hvp_method)
+
+
+ # ------------------------- update RMS preconditioner ------------------------ #
+ b = var.get_update()
+ P_mm = None
+ rms_beta = self.defaults["rms_beta"]
+ if rms_beta is not None:
+ exp_avg_sq = self.get_state(params, "exp_avg_sq", init=b, cls=TensorList)
+ exp_avg_sq.mul_(rms_beta).addcmul(b, b, value=1-rms_beta)
+ exp_avg_sq_sqrt = exp_avg_sq.sqrt().add_(1e-8)
+ def _P_mm(x):
+ return x / exp_avg_sq_sqrt
+ P_mm = _P_mm
+
+ # -------------------------------- inner step -------------------------------- #
+ if 'inner' in self.children:
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+ b = as_tensorlist(b)
+
+ # ------------------------------- trust region ------------------------------- #
+ success = False
+ d = None
+ x0 = [p.clone() for p in params]
+ solution = None
+
+ while not success:
+ max_attempts -= 1
+ if max_attempts < 0: break
+
+ trust_radius = self.global_state.get('trust_radius', init)
+
+ # -------------- make sure trust radius isn't too small or large ------------- #
+ finfo = torch.finfo(x0[0].dtype)
+ if trust_radius < finfo.tiny * 2:
+ trust_radius = self.global_state['trust_radius'] = init
+ if adapt_tol:
+ self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
+
+ elif trust_radius > finfo.max / 2:
+ trust_radius = self.global_state['trust_radius'] = init
+
+ # ----------------------------------- solve ---------------------------------- #
+ d = None
+ if solution is not None and solution.history is not None:
+ d = find_within_trust_radius(solution.history, trust_radius)
+
+ if d is None:
+ if solver == 'cg':
+ d, solution = cg(
+ A_mm=H_mm,
+ b=b,
+ tol=tol,
+ maxiter=maxiter,
+ reg=reg,
+ trust_radius=trust_radius,
+ miniter=miniter,
+ npc_terminate=npc_terminate,
+ history_size=max_history,
+ P_mm=P_mm,
+ )
+
+ elif solver == 'minres':
+ d = minres(A_mm=H_mm, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
+
+ else:
+ raise ValueError(f"unknown solver {solver}")
+
+ # ---------------------------- update trust radius --------------------------- #
+ self.global_state["trust_radius"], success = default_radius(
+ params=params,
+ closure=closure,
+ f=tofloat(var.get_loss(False)),
+ g=b,
+ H=H_mm,
+ d=d,
+ trust_radius=trust_radius,
+ eta=eta,
+ nplus=nplus,
+ nminus=nminus,
+ rho_good=rho_good,
+ rho_bad=rho_bad,
+ boundary_tol=boundary_tol,
+
+ init=init, # init isn't used because check_overflow=False
+ state=self.global_state, # not used
+ settings=self.defaults, # not used
+ check_overflow=False, # this is checked manually to adapt tolerance
+ )
+
+ # --------------------------- assign new direction --------------------------- #
+ assert d is not None
+ if success:
+ var.update = d
+
+ else:
+ var.update = params.zeros_like()
+
+ self._num_hvps += self._num_hvps_last_step
+ return var
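
The docstrings above repeatedly state that the closure passed to the optimizer step must accept a ``backward`` argument. Below is a minimal sketch of that closure convention, following the usage pattern from the torchzero documentation rather than anything specific to this release; the model, data, and loss are placeholders.

    import torch
    import torchzero as tz

    # Placeholder model and data; only the closure pattern matters here.
    model = torch.nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.NewtonCG(),       # relies on autograd HVPs, so it comes first in the chain
        tz.m.Backtracking(),
    )

    def closure(backward=True):
        # The optimizer re-evaluates this closure as needed; gradients are
        # computed only when it passes backward=True.
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    loss = opt.step(closure)
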
@@ -10,12 +10,60 @@ from ...core import Chainable, apply_transform, Module
  from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg

  class NystromSketchAndSolve(Module):
+ """Newton's method with a Nyström sketch-and-solve solver.
+
+ .. note::
+ This module requires the a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. note::
+ In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ .. note::
+ If this is unstable, increase the :code:`reg` parameter and tune the rank.
+
+ .. note:
+ :code:`tz.m.NystromPCG` usually outperforms this.
+
+ Args:
+ rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
+ reg (float, optional): regularization parameter. Defaults to 1e-3.
+ hvp_method (str, optional):
+ Determines how Hessian-vector products are evaluated.
+
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+ This requires creating a graph for the gradient.
+ - ``"forward"``: Use a forward finite difference formula to
+ approximate the HVP. This requires one extra gradient evaluation.
+ - ``"central"``: Use a central finite difference formula for a
+ more accurate HVP approximation. This requires two extra
+ gradient evaluations.
+ Defaults to "autograd".
+ h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+ inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
+ seed (int | None, optional): seed for random generator. Defaults to None.
+
+ Examples:
+ NystromSketchAndSolve with backtracking line search
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NystromSketchAndSolve(10),
+ tz.m.Backtracking()
+ )
+
+ Reference:
+ Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
+ """
  def __init__(
  self,
  rank: int,
  reg: float = 1e-3,
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
- h=1e-3,
+ h: float = 1e-3,
  inner: Chainable | None = None,
  seed: int | None = None,
  ):
@@ -86,6 +134,61 @@ class NystromSketchAndSolve(Module):


  class NystromPCG(Module):
+ """Newton's method with a Nyström-preconditioned conjugate gradient solver.
+ This tends to outperform NewtonCG but requires tuning sketch size.
+ An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.
+
+ .. note::
+ This module requires the a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. note::
+ In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ Args:
+ sketch_size (int):
+ size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
+ running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
+ conjugate gradient.
+ maxiter (int | None, optional):
+ maximum number of iterations. By default this is set to the number of dimensions
+ in the objective function, which is supposed to be enough for conjugate gradient
+ to have guaranteed convergence. Setting this to a small value can still generate good enough directions.
+ Defaults to None.
+ tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
+ reg (float, optional): regularization parameter. Defaults to 1e-8.
+ hvp_method (str, optional):
+ Determines how Hessian-vector products are evaluated.
+
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+ This requires creating a graph for the gradient.
+ - ``"forward"``: Use a forward finite difference formula to
+ approximate the HVP. This requires one extra gradient evaluation.
+ - ``"central"``: Use a central finite difference formula for a
+ more accurate HVP approximation. This requires two extra
+ gradient evaluations.
+ Defaults to "autograd".
+ h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+ inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
+ seed (int | None, optional): seed for random generator. Defaults to None.
+
+ Examples:
+
+ NystromPCG with backtracking line search
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NystromPCG(10),
+ tz.m.Backtracking()
+ )
+
+ Reference:
+ Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
+
+ """
  def __init__(
  self,
  sketch_size: int,
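
The ``hvp_method`` option documented for NewtonCG, NewtonCGSteihaug, NystromSketchAndSolve, and NystromPCG chooses between exact autograd Hessian-vector products and the two finite-difference approximations described above. Here is a self-contained sketch of those formulas written with plain ``torch.autograd``; the helper name ``fd_hvp`` is illustrative and is not part of torchzero's API.

    import torch

    def fd_hvp(f, x, v, h=1e-3, central=False):
        # Finite-difference Hessian-vector products as described in the docstrings:
        #   forward: Hv ~ (grad f(x + h*v) - grad f(x)) / h        (one extra gradient)
        #   central: Hv ~ (grad f(x + h*v) - grad f(x - h*v)) / 2h (two extra gradients)
        def grad_at(p):
            p = p.detach().requires_grad_(True)
            return torch.autograd.grad(f(p), p)[0]
        if central:
            return (grad_at(x + h * v) - grad_at(x - h * v)) / (2 * h)
        return (grad_at(x + h * v) - grad_at(x)) / h

    x, v = torch.randn(5), torch.randn(5)
    f = lambda p: (p ** 4).sum()
    print(fd_hvp(f, x, v))                # forward difference
    print(fd_hvp(f, x, v, central=True))  # more accurate central difference
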
@@ -1,2 +1,2 @@
  from .laplacian import LaplacianSmoothing
- from .gaussian import GaussianHomotopy
+ from .sampling import GradientSampling
@@ -56,7 +56,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
  return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable

  class LaplacianSmoothing(Transform):
- """Applies laplacian smoothing via a fast Fourier transform solver.
+ """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.

  Args:
  sigma (float, optional): controls the amount of smoothing. Defaults to 1.
@@ -69,9 +69,19 @@ class LaplacianSmoothing(Transform):
  target (str, optional):
  what to set on var.

+ Examples:
+ Laplacian Smoothing Gradient Descent optimizer as in the paper
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LaplacianSmoothing(),
+ tz.m.LR(1e-2),
+ )
+
  Reference:
- *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
- Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+ Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.

  """
  def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
@@ -82,7 +92,7 @@ class LaplacianSmoothing(Transform):


  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  layerwise = settings[0]['layerwise']

  # layerwise laplacian smoothing