torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -1,35 +1,111 @@
1
- from collections.abc import Callable
2
1
  from typing import Literal, overload
3
- import warnings
4
2
  import torch
5
3
 
6
- from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
4
+ from ...utils import TensorList, as_tensorlist, NumberList
7
5
  from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
8
6
 
9
- from ...core import Chainable, apply, Module
10
- from ...utils.linalg.solve import cg
7
+ from ...core import Chainable, apply_transform, Module
8
+ from ...utils.linalg.solve import cg, steihaug_toint_cg, minres
11
9
 
12
10
  class NewtonCG(Module):
11
+ """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
12
+
13
+ This optimizer implements Newton's method using a matrix-free conjugate
14
+ gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
15
+ forming the full Hessian matrix, it only requires Hessian-vector products
16
+ (HVPs). These can be calculated efficiently using automatic
17
+ differentiation or approximated using finite differences.
18
+
19
+ .. note::
20
+ In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
21
+
22
+ .. note::
23
+ This module requires a closure to be passed to the optimizer step,
24
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
25
+ The closure must accept a ``backward`` argument (refer to documentation).
26
+
27
+ .. warning::
28
+ CG may fail if the Hessian is not positive-definite.
29
+
30
+ Args:
31
+ maxiter (int | None, optional):
32
+ Maximum number of iterations for the conjugate gradient solver.
33
+ By default, this is set to the number of dimensions in the
34
+ objective function, which is the theoretical upper bound for CG
35
+ convergence. Setting this to a smaller value (truncated Newton)
36
+ can still generate good search directions. Defaults to None.
37
+ tol (float, optional):
38
+ Relative tolerance for the conjugate gradient solver to determine
39
+ convergence. Defaults to 1e-4.
40
+ reg (float, optional):
41
+ Regularization parameter (damping) added to the Hessian diagonal.
42
+ This helps ensure the system is positive-definite. Defaults to 1e-8.
43
+ hvp_method (str, optional):
44
+ Determines how Hessian-vector products are evaluated.
45
+
46
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
47
+ This requires creating a graph for the gradient.
48
+ - ``"forward"``: Use a forward finite difference formula to
49
+ approximate the HVP. This requires one extra gradient evaluation.
50
+ - ``"central"``: Use a central finite difference formula for a
51
+ more accurate HVP approximation. This requires two extra
52
+ gradient evaluations.
53
+ Defaults to "autograd".
54
+ h (float, optional):
55
+ The step size for finite differences if :code:`hvp_method` is
56
+ ``"forward"`` or ``"central"``. Defaults to 1e-3.
57
+ warm_start (bool, optional):
58
+ If ``True``, the conjugate gradient solver is initialized with the
59
+ solution from the previous optimization step. This can accelerate
60
+ convergence, especially in truncated Newton methods.
61
+ Defaults to False.
62
+ inner (Chainable | None, optional):
63
+ NewtonCG will attempt to apply preconditioning to the output of this module.
64
+
65
+ Examples:
66
+ Newton-CG with a backtracking line search:
67
+
68
+ .. code-block:: python
69
+
70
+ opt = tz.Modular(
71
+ model.parameters(),
72
+ tz.m.NewtonCG(),
73
+ tz.m.Backtracking()
74
+ )
75
+
76
+ Truncated Newton method (useful for large-scale problems):
77
+
78
+ .. code-block:: python
79
+
80
+ opt = tz.Modular(
81
+ model.parameters(),
82
+ tz.m.NewtonCG(maxiter=10, warm_start=True),
83
+ tz.m.Backtracking()
84
+ )
85
+
86
+
87
+ """
13
88
  def __init__(
14
89
  self,
15
- maxiter=None,
16
- tol=1e-3,
90
+ maxiter: int | None = None,
91
+ tol: float = 1e-4,
17
92
  reg: float = 1e-8,
18
- hvp_method: Literal["forward", "central", "autograd"] = "forward",
19
- h=1e-3,
93
+ hvp_method: Literal["forward", "central", "autograd"] = "autograd",
94
+ solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
95
+ h: float = 1e-3,
20
96
  warm_start=False,
21
97
  inner: Chainable | None = None,
22
98
  ):
23
- defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, warm_start=warm_start)
99
+ defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
24
100
  super().__init__(defaults,)
25
101
 
26
102
  if inner is not None:
27
103
  self.set_child('inner', inner)
28
104
 
29
105
  @torch.no_grad
30
- def step(self, vars):
31
- params = TensorList(vars.params)
32
- closure = vars.closure
106
+ def step(self, var):
107
+ params = TensorList(var.params)
108
+ closure = var.closure
33
109
  if closure is None: raise RuntimeError('NewtonCG requires closure')
34
110
 
35
111
  settings = self.settings[params[0]]
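The ``backward`` closure convention mentioned in the note above is easiest to see in code. Below is a minimal sketch of a training step with this module, assuming the usual torchzero pattern in which the closure recomputes the loss and only runs ``zero_grad()``/``loss.backward()`` when ``backward=True``; the model, data and loss function here are illustrative placeholders:

    import torch
    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(4, 1)
    X, y = torch.randn(32, 4), torch.randn(32, 1)

    opt = tz.Modular(model.parameters(), tz.m.NewtonCG(), tz.m.Backtracking())

    def closure(backward=True):
        # NewtonCG may evaluate this several times per step (loss values and HVPs)
        loss = nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)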
@@ -37,12 +113,13 @@ class NewtonCG(Module):
37
113
  reg = settings['reg']
38
114
  maxiter = settings['maxiter']
39
115
  hvp_method = settings['hvp_method']
116
+ solver = settings['solver'].lower().strip()
40
117
  h = settings['h']
41
118
  warm_start = settings['warm_start']
42
119
 
43
120
  # ---------------------- Hessian vector product function --------------------- #
44
121
  if hvp_method == 'autograd':
45
- grad = vars.get_grad(create_graph=True)
122
+ grad = var.get_grad(create_graph=True)
46
123
 
47
124
  def H_mm(x):
48
125
  with torch.enable_grad():
@@ -51,7 +128,7 @@ class NewtonCG(Module):
51
128
  else:
52
129
 
53
130
  with torch.enable_grad():
54
- grad = vars.get_grad()
131
+ grad = var.get_grad()
55
132
 
56
133
  if hvp_method == 'forward':
57
134
  def H_mm(x):
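For reference, the ``"forward"`` and ``"central"`` options documented above use the standard finite-difference formulas Hv ≈ (∇f(x + h·v) - ∇f(x)) / h and Hv ≈ (∇f(x + h·v) - ∇f(x - h·v)) / (2h). A tiny self-contained illustration on a quadratic, where the exact product is known (illustrative only; the library's ``hvp_fd_forward``/``hvp_fd_central`` helpers additionally normalize the direction):

    import torch

    A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])  # Hessian of f(x) = 0.5 * x @ A @ x

    def grad_f(x):
        return A @ x

    x = torch.tensor([1.0, -1.0])
    v = torch.tensor([0.5, 2.0])
    h = 1e-3

    hvp_forward = (grad_f(x + h * v) - grad_f(x)) / h                # one extra gradient
    hvp_central = (grad_f(x + h * v) - grad_f(x - h * v)) / (2 * h)  # two extra gradients
    print(hvp_forward, hvp_central, A @ v)  # all three agree (exactly, for a quadratic)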
@@ -66,19 +143,232 @@ class NewtonCG(Module):
66
143
 
67
144
 
68
145
  # -------------------------------- inner step -------------------------------- #
69
- b = vars.get_update()
146
+ b = var.get_update()
70
147
  if 'inner' in self.children:
71
- b = as_tensorlist(apply(self.children['inner'], b, params=params, grads=grad, vars=vars))
148
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
149
+ b = as_tensorlist(b)
72
150
 
73
151
  # ---------------------------------- run cg ---------------------------------- #
74
152
  x0 = None
75
- if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
76
- x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
153
+ if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
154
+
155
+ if solver == 'cg':
156
+ x = cg(A_mm=H_mm, b=b, x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
157
+
158
+ elif solver == 'minres':
159
+ x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
160
+
161
+ elif solver == 'minres_npc':
162
+ x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
163
+
164
+ else:
165
+ raise ValueError(f"Unknown solver {solver}")
166
+
77
167
  if warm_start:
78
168
  assert x0 is not None
79
169
  x0.copy_(x)
80
170
 
81
- vars.update = x
82
- return vars
171
+ var.update = x
172
+ return var
173
+
174
+
175
+ class TruncatedNewtonCG(Module):
176
+ """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
177
+
178
+ This optimizer implements Newton's method using a matrix-free conjugate
179
+ gradient (CG) solver to approximate the search direction. Instead of
180
+ forming the full Hessian matrix, it only requires Hessian-vector products
181
+ (HVPs). These can be calculated efficiently using automatic
182
+ differentiation or approximated using finite differences.
183
+
184
+ .. note::
185
+ In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
186
+
187
+ .. note::
188
+ This module requires a closure to be passed to the optimizer step,
189
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
190
+ The closure must accept a ``backward`` argument (refer to documentation).
191
+
192
+ .. warning::
193
+ CG may fail if the Hessian is not positive-definite.
194
+
195
+ Args:
196
+ maxiter (int | None, optional):
197
+ Maximum number of iterations for the conjugate gradient solver.
198
+ By default, this is set to the number of dimensions in the
199
+ objective function, which is the theoretical upper bound for CG
200
+ convergence. Setting this to a smaller value (truncated Newton)
201
+ can still generate good search directions. Defaults to None.
202
+ eta (float, optional):
203
+ a step is accepted whenever the ratio of actual to predicted loss reduction is larger than this value.
204
+ nplus (float, optional):
205
+ trust region multiplier on successful steps.
206
+ nminus (float, optional):
207
+ trust region multiplier on unsuccessful steps.
208
+ init (float, optional): initial trust region.
209
+ tol (float, optional):
210
+ Relative tolerance for the conjugate gradient solver to determine
211
+ convergence. Defaults to 1e-4.
212
+ reg (float, optional):
213
+ Regularization parameter (damping) added to the Hessian diagonal.
214
+ This helps ensure the system is positive-definite. Defaults to 1e-8.
215
+ hvp_method (str, optional):
216
+ Determines how Hessian-vector products are evaluated.
217
+
218
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
219
+ This requires creating a graph for the gradient.
220
+ - ``"forward"``: Use a forward finite difference formula to
221
+ approximate the HVP. This requires one extra gradient evaluation.
222
+ - ``"central"``: Use a central finite difference formula for a
223
+ more accurate HVP approximation. This requires two extra
224
+ gradient evaluations.
225
+ Defaults to "autograd".
226
+ h (float, optional):
227
+ The step size for finite differences if :code:`hvp_method` is
228
+ ``"forward"`` or ``"central"``. Defaults to 1e-3.
229
+ inner (Chainable | None, optional):
230
+ NewtonCG will attempt to apply preconditioning to the output of this module.
231
+
232
+ Examples:
233
+ Trust-region Newton-CG:
234
+
235
+ .. code-block:: python
236
+
237
+ opt = tz.Modular(
238
+ model.parameters(),
239
+ tz.m.NewtonCGSteihaug(),
240
+ )
241
+
242
+ Reference:
243
+ Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
244
+ """
245
+ def __init__(
246
+ self,
247
+ maxiter: int | None = None,
248
+ eta: float= 1e-6,
249
+ nplus: float = 2,
250
+ nminus: float = 0.25,
251
+ init: float = 1,
252
+ tol: float = 1e-4,
253
+ reg: float = 1e-8,
254
+ hvp_method: Literal["forward", "central", "autograd"] = "autograd",
255
+ solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
256
+ h: float = 1e-3,
257
+ max_attempts: int = 10,
258
+ inner: Chainable | None = None,
259
+ ):
260
+ defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, eta=eta, nplus=nplus, nminus=nminus, init=init, max_attempts=max_attempts, solver=solver)
261
+ super().__init__(defaults,)
262
+
263
+ if inner is not None:
264
+ self.set_child('inner', inner)
265
+
266
+ @torch.no_grad
267
+ def step(self, var):
268
+ params = TensorList(var.params)
269
+ closure = var.closure
270
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
271
+
272
+ settings = self.settings[params[0]]
273
+ tol = settings['tol']
274
+ reg = settings['reg']
275
+ maxiter = settings['maxiter']
276
+ hvp_method = settings['hvp_method']
277
+ h = settings['h']
278
+ max_attempts = settings['max_attempts']
279
+ solver = settings['solver'].lower().strip()
280
+
281
+ eta = settings['eta']
282
+ nplus = settings['nplus']
283
+ nminus = settings['nminus']
284
+ init = settings['init']
285
+
286
+ # ---------------------- Hessian vector product function --------------------- #
287
+ if hvp_method == 'autograd':
288
+ grad = var.get_grad(create_graph=True)
289
+
290
+ def H_mm(x):
291
+ with torch.enable_grad():
292
+ return TensorList(hvp(params, grad, x, retain_graph=True))
293
+
294
+ else:
295
+
296
+ with torch.enable_grad():
297
+ grad = var.get_grad()
298
+
299
+ if hvp_method == 'forward':
300
+ def H_mm(x):
301
+ return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
302
+
303
+ elif hvp_method == 'central':
304
+ def H_mm(x):
305
+ return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
306
+
307
+ else:
308
+ raise ValueError(hvp_method)
309
+
310
+
311
+ # -------------------------------- inner step -------------------------------- #
312
+ b = var.get_update()
313
+ if 'inner' in self.children:
314
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
315
+ b = as_tensorlist(b)
316
+
317
+ # ---------------------------------- run cg ---------------------------------- #
318
+ success = False
319
+ x = None
320
+ while not success:
321
+ max_attempts -= 1
322
+ if max_attempts < 0: break
323
+
324
+ trust_region = self.global_state.get('trust_region', init)
325
+ if trust_region < 1e-8 or trust_region > 1e8:
326
+ trust_region = self.global_state['trust_region'] = init
327
+
328
+ if solver == 'cg':
329
+ x = steihaug_toint_cg(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg)
330
+
331
+ elif solver == 'minres':
332
+ x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
333
+
334
+ elif solver == 'minres_npc':
335
+ x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
336
+
337
+ else:
338
+ raise ValueError(f"unknown solver {solver}")
339
+
340
+ # ------------------------------- trust region ------------------------------- #
341
+ Hx = H_mm(x)
342
+ pred_reduction = b.dot(x) - 0.5 * x.dot(Hx)
343
+
344
+ params -= x
345
+ loss_star = closure(False)
346
+ params += x
347
+ reduction = var.get_loss(False) - loss_star
348
+
349
+ rho = reduction / (pred_reduction.clip(min=1e-8))
350
+
351
+ # failed step
352
+ if rho < 0.25:
353
+ self.global_state['trust_region'] = trust_region * nminus
354
+
355
+ # very good step
356
+ elif rho > 0.75:
357
+ diff = trust_region - x.abs()
358
+ if (diff.global_min() / trust_region) > 1e-4: # hits boundary
359
+ self.global_state['trust_region'] = trust_region * nplus
360
+
361
+ # if the ratio is high enough then accept the proposed step
362
+ if rho > eta:
363
+ success = True
364
+
365
+ assert x is not None
366
+ if success:
367
+ var.update = x
368
+
369
+ else:
370
+ var.update = params.zeros_like()
371
+
372
+ return var
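The trust-region loop above accepts or rejects the proposed step ``x`` by comparing the actual loss reduction with the reduction predicted by the quadratic model, bᵀx - 0.5·xᵀHx, and then rescales the radius. A schematic restatement of that rule in plain Python (the 0.25/0.75 thresholds and the roles of ``eta``/``nplus``/``nminus`` mirror the code above; the boundary test is abstracted into a flag):

    def update_trust_region(actual_reduction, predicted_reduction, trust_region,
                            eta=1e-6, nplus=2.0, nminus=0.25, hit_boundary=False):
        """Return (accepted, new_trust_region) following the logic above."""
        rho = actual_reduction / max(predicted_reduction, 1e-8)
        if rho < 0.25:                       # poor model agreement: shrink the region
            trust_region *= nminus
        elif rho > 0.75 and hit_boundary:    # good agreement and the step was constrained: expand
            trust_region *= nplus
        accepted = rho > eta                 # accept whenever the ratio clears eta
        return accepted, trust_region

    print(update_trust_region(0.9, 1.0, trust_region=1.0))   # (True, 1.0)
    print(update_trust_region(-0.1, 1.0, trust_region=1.0))  # (False, 0.25)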
83
373
 
84
374
 
@@ -6,16 +6,64 @@ import torch
6
6
  from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
7
7
  from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
8
8
 
9
- from ...core import Chainable, apply, Module
9
+ from ...core import Chainable, apply_transform, Module
10
10
  from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg
11
11
 
12
12
  class NystromSketchAndSolve(Module):
13
+ """Newton's method with a Nyström sketch-and-solve solver.
14
+
15
+ .. note::
16
+ This module requires a closure to be passed to the optimizer step,
17
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
18
+ The closure must accept a ``backward`` argument (refer to documentation).
19
+
20
+ .. note::
21
+ In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
22
+
23
+ .. note::
24
+ If this is unstable, increase the :code:`reg` parameter and tune the rank.
25
+
26
+ .. note::
27
+ :code:`tz.m.NystromPCG` usually outperforms this.
28
+
29
+ Args:
30
+ rank (int): size of the sketch; this many Hessian-vector products will be evaluated per step.
31
+ reg (float, optional): regularization parameter. Defaults to 1e-3.
32
+ hvp_method (str, optional):
33
+ Determines how Hessian-vector products are evaluated.
34
+
35
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
36
+ This requires creating a graph for the gradient.
37
+ - ``"forward"``: Use a forward finite difference formula to
38
+ approximate the HVP. This requires one extra gradient evaluation.
39
+ - ``"central"``: Use a central finite difference formula for a
40
+ more accurate HVP approximation. This requires two extra
41
+ gradient evaluations.
42
+ Defaults to "autograd".
43
+ h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
44
+ inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
45
+ seed (int | None, optional): seed for random generator. Defaults to None.
46
+
47
+ Examples:
48
+ NystromSketchAndSolve with backtracking line search
49
+
50
+ .. code-block:: python
51
+
52
+ opt = tz.Modular(
53
+ model.parameters(),
54
+ tz.m.NystromSketchAndSolve(10),
55
+ tz.m.Backtracking()
56
+ )
57
+
58
+ Reference:
59
+ Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
60
+ """
13
61
  def __init__(
14
62
  self,
15
63
  rank: int,
16
64
  reg: float = 1e-3,
17
65
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
18
- h=1e-2,
66
+ h: float = 1e-3,
19
67
  inner: Chainable | None = None,
20
68
  seed: int | None = None,
21
69
  ):
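The ``rank`` argument above controls how many Hessian-vector products feed the randomized Nyström approximation Â ≈ U·diag(λ̂)·Uᵀ, which sketch-and-solve then inverts in closed form. A self-contained sketch of the general technique on an explicit matrix, loosely following the construction in the Frangella et al. reference (this is not the library's ``nystrom_sketch_and_solve`` routine; details such as the stabilizing shift are simplified):

    import torch

    torch.manual_seed(0)
    n, rank, reg = 50, 10, 1e-3

    # PSD test matrix with a few dominant eigenvalues (a stand-in for a Hessian)
    Q, _ = torch.linalg.qr(torch.randn(n, n))
    eigs = torch.cat([torch.logspace(1, 0, 8), torch.full((n - 8,), 1e-6)])
    A = Q @ torch.diag(eigs) @ Q.T
    b = torch.randn(n)

    # rank-`rank` Nystrom approximation built from `rank` matrix-vector products
    Omega = torch.linalg.qr(torch.randn(n, rank)).Q        # orthonormal test matrix
    Y = A @ Omega                                          # the HVPs in the optimizer
    nu = 1e-7 * torch.linalg.norm(Y)                       # small shift for stability
    Y_nu = Y + nu * Omega
    M = Omega.T @ Y_nu
    L = torch.linalg.cholesky((M + M.T) / 2)
    B = torch.linalg.solve_triangular(L, Y_nu.T, upper=False).T
    U, S, _ = torch.linalg.svd(B, full_matrices=False)
    lam = torch.clamp(S**2 - nu, min=0)                    # approximate top eigenvalues

    # solve (A_hat + reg*I) x = b exactly, using A_hat = U diag(lam) U^T
    Ub = U.T @ b
    x = U @ (Ub / (lam + reg)) + (b - U @ Ub) / reg

    # error is small when the sketch captures the dominant eigenspace
    x_ref = torch.linalg.solve(A + reg * torch.eye(n), b)
    print(torch.linalg.norm(x - x_ref) / torch.linalg.norm(x_ref))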
@@ -26,10 +74,10 @@ class NystromSketchAndSolve(Module):
26
74
  self.set_child('inner', inner)
27
75
 
28
76
  @torch.no_grad
29
- def step(self, vars):
30
- params = TensorList(vars.params)
77
+ def step(self, var):
78
+ params = TensorList(var.params)
31
79
 
32
- closure = vars.closure
80
+ closure = var.closure
33
81
  if closure is None: raise RuntimeError('NewtonCG requires closure')
34
82
 
35
83
  settings = self.settings[params[0]]
@@ -47,7 +95,7 @@ class NystromSketchAndSolve(Module):
47
95
 
48
96
  # ---------------------- Hessian vector product function --------------------- #
49
97
  if hvp_method == 'autograd':
50
- grad = vars.get_grad(create_graph=True)
98
+ grad = var.get_grad(create_graph=True)
51
99
 
52
100
  def H_mm(x):
53
101
  with torch.enable_grad():
@@ -57,7 +105,7 @@ class NystromSketchAndSolve(Module):
57
105
  else:
58
106
 
59
107
  with torch.enable_grad():
60
- grad = vars.get_grad()
108
+ grad = var.get_grad()
61
109
 
62
110
  if hvp_method == 'forward':
63
111
  def H_mm(x):
@@ -74,18 +122,73 @@ class NystromSketchAndSolve(Module):
74
122
 
75
123
 
76
124
  # -------------------------------- inner step -------------------------------- #
77
- b = vars.get_update()
125
+ b = var.get_update()
78
126
  if 'inner' in self.children:
79
- b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)
127
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
80
128
 
81
129
  # ------------------------------ sketch&n&solve ------------------------------ #
82
130
  x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
83
- vars.update = vec_to_tensors(x, reference=params)
84
- return vars
131
+ var.update = vec_to_tensors(x, reference=params)
132
+ return var
85
133
 
86
134
 
87
135
 
88
136
  class NystromPCG(Module):
137
+ """Newton's method with a Nyström-preconditioned conjugate gradient solver.
138
+ This tends to outperform NewtonCG but requires tuning the sketch size.
139
+ An adaptive version exists in https://arxiv.org/abs/2110.02820; I might implement it too at some point.
140
+
141
+ .. note::
142
+ This module requires a closure to be passed to the optimizer step,
143
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
144
+ The closure must accept a ``backward`` argument (refer to documentation).
145
+
146
+ .. note::
147
+ In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
148
+
149
+ Args:
150
+ sketch_size (int):
151
+ size of the sketch for preconditioning; this many Hessian-vector products will be evaluated before
152
+ running the conjugate gradient solver. A larger value improves the preconditioning and speeds up
153
+ the conjugate gradient solver.
154
+ maxiter (int | None, optional):
155
+ maximum number of iterations. By default this is set to the number of dimensions
156
+ in the objective function, which is supposed to be enough for conjugate gradient
157
+ to have guaranteed convergence. Setting this to a small value can still generate good enough directions.
158
+ Defaults to None.
159
+ tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
160
+ reg (float, optional): regularization parameter. Defaults to 1e-8.
161
+ hvp_method (str, optional):
162
+ Determines how Hessian-vector products are evaluated.
163
+
164
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
165
+ This requires creating a graph for the gradient.
166
+ - ``"forward"``: Use a forward finite difference formula to
167
+ approximate the HVP. This requires one extra gradient evaluation.
168
+ - ``"central"``: Use a central finite difference formula for a
169
+ more accurate HVP approximation. This requires two extra
170
+ gradient evaluations.
171
+ Defaults to "autograd".
172
+ h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
173
+ inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
174
+ seed (int | None, optional): seed for random generator. Defaults to None.
175
+
176
+ Examples:
177
+
178
+ NystromPCG with backtracking line search
179
+
180
+ .. code-block:: python
181
+
182
+ opt = tz.Modular(
183
+ model.parameters(),
184
+ tz.m.NystromPCG(10),
185
+ tz.m.Backtracking()
186
+ )
187
+
188
+ Reference:
189
+ Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
190
+
191
+ """
89
192
  def __init__(
90
193
  self,
91
194
  sketch_size: int,
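``NystromPCG`` uses the same U, λ̂ factors not to solve directly but as a preconditioner inside conjugate gradient. A rough sketch of one standard way to apply such a preconditioner, following the form given in the Frangella et al. reference (the exact scaling and termination details of torchzero's ``nystrom_pcg`` may differ):

    import torch

    def nystrom_preconditioner(U, lam, reg):
        """Return a function applying P^{-1} v for a Nystrom preconditioner.

        U: (n, r) orthonormal approximate eigenvectors; lam: (r,) approximate
        eigenvalues in descending order; reg: damping added to the Hessian.
        P^{-1} v = (lam_r + reg) * U (diag(lam) + reg I)^{-1} U^T v + (v - U U^T v)
        """
        lam_r = lam[-1]  # smallest retained eigenvalue
        def apply_inv(v):
            Uv = U.T @ v
            return (lam_r + reg) * (U @ (Uv / (lam + reg))) + (v - U @ Uv)
        return apply_inv

    # apply_inv replaces the identity preconditioner in a standard preconditioned
    # CG loop: z_k = apply_inv(r_k) at each iteration.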
@@ -93,7 +196,7 @@ class NystromPCG(Module):
93
196
  tol=1e-3,
94
197
  reg: float = 1e-6,
95
198
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
96
- h=1e-2,
199
+ h=1e-3,
97
200
  inner: Chainable | None = None,
98
201
  seed: int | None = None,
99
202
  ):
@@ -104,10 +207,10 @@ class NystromPCG(Module):
104
207
  self.set_child('inner', inner)
105
208
 
106
209
  @torch.no_grad
107
- def step(self, vars):
108
- params = TensorList(vars.params)
210
+ def step(self, var):
211
+ params = TensorList(var.params)
109
212
 
110
- closure = vars.closure
213
+ closure = var.closure
111
214
  if closure is None: raise RuntimeError('NewtonCG requires closure')
112
215
 
113
216
  settings = self.settings[params[0]]
@@ -129,7 +232,7 @@ class NystromPCG(Module):
129
232
 
130
233
  # ---------------------- Hessian vector product function --------------------- #
131
234
  if hvp_method == 'autograd':
132
- grad = vars.get_grad(create_graph=True)
235
+ grad = var.get_grad(create_graph=True)
133
236
 
134
237
  def H_mm(x):
135
238
  with torch.enable_grad():
@@ -139,7 +242,7 @@ class NystromPCG(Module):
139
242
  else:
140
243
 
141
244
  with torch.enable_grad():
142
- grad = vars.get_grad()
245
+ grad = var.get_grad()
143
246
 
144
247
  if hvp_method == 'forward':
145
248
  def H_mm(x):
@@ -156,13 +259,13 @@ class NystromPCG(Module):
156
259
 
157
260
 
158
261
  # -------------------------------- inner step -------------------------------- #
159
- b = vars.get_update()
262
+ b = var.get_update()
160
263
  if 'inner' in self.children:
161
- b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)
264
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
162
265
 
163
266
  # ------------------------------ sketch&n&solve ------------------------------ #
164
267
  x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
165
- vars.update = vec_to_tensors(x, reference=params)
166
- return vars
268
+ var.update = vec_to_tensors(x, reference=params)
269
+ return var
167
270
 
168
271