torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -44,8 +44,8 @@ def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_ne
   # use eigvec or -eigvec depending on if it points in same direction as gradient
   return g.dot(d).sign() * d

- L.reciprocal_()
- return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
+ return Q @ ((Q.mH @ g) / L)
+
   except torch.linalg.LinAlgError:
   return None
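For context, the replacement line above solves H x = g through the eigendecomposition H = Q diag(L) Qᵀ by dividing the projected gradient by the eigenvalues, instead of reciprocating L in place and calling multi_dot. A minimal standalone sketch of that solve in plain PyTorch (function name is illustrative, not torchzero API):

    import torch

    def eigh_solve_sketch(H: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        # H is assumed symmetric; eigh gives H = Q @ diag(L) @ Q.mH
        L, Q = torch.linalg.eigh(H)
        # solving H x = g reduces to scaling the projected gradient by 1/L
        return Q @ ((Q.mH @ g) / L)

    H = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
    g = torch.tensor([1.0, 1.0])
    x = eigh_solve_sketch(H, g)
    print(torch.allclose(H @ x, g))  # True up to floating point error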
 
@@ -53,46 +53,109 @@ def tikhonov_(H: torch.Tensor, reg: float):
   if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
   return H

- def eig_tikhonov_(H: torch.Tensor, reg: float):
- v = torch.linalg.eigvalsh(H).min().clamp_(max=0).neg_() + reg # pylint:disable=not-callable
- return tikhonov_(H, v)
-

   class Newton(Module):
- """Exact newton via autograd.
+ """Exact newton's method via autograd.
+
+ .. note::
+ In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating the hessian.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+

   Args:
   reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
- eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
   search_negative (bool, Optional):
- if True, whenever a negative eigenvalue is detected, the direction is taken along an eigenvector corresponding to a negative eigenvalue.
+ if True, whenever a negative eigenvalue is detected,
+ search direction is proposed along an eigenvector corresponding to a negative eigenvalue.
   hessian_method (str):
   how to calculate hessian. Defaults to "autograd".
   vectorize (bool, optional):
   whether to enable vectorized hessian. Defaults to True.
- inner (Chainable | None, optional): inner modules. Defaults to None.
+ inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
   H_tfm (Callable | None, optional):
   optional hessian transforms, takes in two arguments - `(hessian, gradient)`.

- must return a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
- which must be True if transform inverted the hessian and False otherwise. Defaults to None.
+ must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+ which must be True if transform inverted the hessian and False otherwise.
+
+ Or it returns a single tensor which is used as the update.
+
+ Defaults to None.
   eigval_tfm (Callable | None, optional):
   optional eigenvalues transform, for example :code:`torch.abs` or :code:`lambda L: torch.clip(L, min=1e-8)`.
- If this is specified, eigendecomposition will be used to solve Hx = g.
+ If this is specified, eigendecomposition will be used to invert the hessian.
+
+ Examples:
+ Newton's method with backtracking line search
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Newton(),
+ tz.m.Backtracking()
+ )
+
+ Newton's method modified for non-convex functions by taking matrix absolute value of the hessian
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Newton(eigval_tfm=lambda x: torch.abs(x).clip(min=0.1)),
+ tz.m.Backtracking()
+ )
+
+ Newton's method modified for non-convex functions by searching along negative curvature directions
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Newton(search_negative=True),
+ tz.m.Backtracking()
+ )
+
+ Newton preconditioning applied to momentum
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Newton(inner=tz.m.EMA(0.9)),
+ tz.m.LR(0.1)
+ )
+
+ Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient, but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
+ tz.m.Backtracking()
+ )

   """
   def __init__(
   self,
   reg: float = 1e-6,
- eig_reg: bool = False,
   search_negative: bool = False,
+ update_freq: int = 1,
   hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
   vectorize: bool = True,
   inner: Chainable | None = None,
- H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
+ H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
   eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
   ):
- defaults = dict(reg=reg, eig_reg=eig_reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative)
+ defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative, update_freq=update_freq)
   super().__init__(defaults)

   if inner is not None:
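To illustrate the extended H_tfm contract described in the docstring above (it may now return either a `(hessian, is_inverted)` pair or a single tensor that is taken directly as the update), here is a small sketch; both callables are user-supplied examples, not part of the torchzero API:

    import torch

    # Variant 1: return (matrix, is_inverted). Returning the (pseudo)inverse and True
    # signals that the matrix can be applied to the gradient directly.
    def pinv_tfm(H: torch.Tensor, g: torch.Tensor):
        return torch.linalg.pinv(H), True

    # Variant 2: return a single tensor, which is used as the update itself.
    # Here: a diagonal-Newton style update that ignores off-diagonal curvature;
    # the clamp avoids dividing by near-zero diagonal entries (illustrative choice).
    def diag_newton_tfm(H: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        return g / H.diag().clamp(min=1e-12)

Either callable would then be passed as tz.m.Newton(H_tfm=...), in the same way as the diagonal newton example shown in the docstring.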
@@ -106,47 +169,66 @@ class Newton(Module):

   settings = self.settings[params[0]]
   reg = settings['reg']
- eig_reg = settings['eig_reg']
   search_negative = settings['search_negative']
   hessian_method = settings['hessian_method']
   vectorize = settings['vectorize']
   H_tfm = settings['H_tfm']
   eigval_tfm = settings['eigval_tfm']
+ update_freq = settings['update_freq']
+
+ step = self.global_state.get('step', 0)
+ self.global_state['step'] = step + 1
+
+ g_list = var.grad
+ H = None
+ if step % update_freq == 0:
+ # ------------------------ calculate grad and hessian ------------------------ #
+ if hessian_method == 'autograd':
+ with torch.enable_grad():
+ loss = var.loss = var.loss_approx = closure(False)
+ g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+ g_list = [t[0] for t in g_list] # remove leading dim from loss
+ var.grad = g_list
+ H = hessian_list_to_mat(H_list)

- # ------------------------ calculate grad and hessian ------------------------ #
- if hessian_method == 'autograd':
- with torch.enable_grad():
- loss = var.loss = var.loss_approx = closure(False)
- g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
- g_list = [t[0] for t in g_list] # remove leading dim from loss
- var.grad = g_list
- H = hessian_list_to_mat(H_list)
-
- elif hessian_method in ('func', 'autograd.functional'):
- strat = 'forward-mode' if vectorize else 'reverse-mode'
- with torch.enable_grad():
- g_list = var.get_grad(retain_graph=True)
- H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
- method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
- else:
- raise ValueError(hessian_method)
+ elif hessian_method in ('func', 'autograd.functional'):
+ strat = 'forward-mode' if vectorize else 'reverse-mode'
+ with torch.enable_grad():
+ g_list = var.get_grad(retain_graph=True)
+ H = hessian_mat(partial(closure, backward=False), params,
+ method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+ else:
+ raise ValueError(hessian_method)
+
+ H = tikhonov_(H, reg)
+ if update_freq != 1:
+ self.global_state['H'] = H
+
+ if H is None:
+ H = self.global_state["H"]
+
+ # var.storage['hessian'] = H

   # -------------------------------- inner step -------------------------------- #
   update = var.get_update()
   if 'inner' in self.children:
- update = apply_transform(self.children['inner'], update, params=params, grads=list(g_list), var=var)
+ update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+
   g = torch.cat([t.ravel() for t in update])

- # ------------------------------- regulazition ------------------------------- #
- if eig_reg: H = eig_tikhonov_(H, reg)
- else: H = tikhonov_(H, reg)

   # ----------------------------------- solve ---------------------------------- #
   update = None
   if H_tfm is not None:
- H, is_inv = H_tfm(H, g)
- if is_inv: update = H @ g
+ ret = H_tfm(H, g)
+
+ if isinstance(ret, torch.Tensor):
+ update = ret
+
+ else: # returns (H, is_inv)
+ H, is_inv = ret
+ if is_inv: update = H @ g

   if search_negative or (eigval_tfm is not None):
   update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)
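The update_freq handling added in this hunk recomputes and regularizes the Hessian only every update_freq steps and reuses the cached matrix in between. A minimal standalone sketch of that caching pattern in plain PyTorch (names and the toy loop are illustrative, not the torchzero internals):

    import torch

    def newton_like_loop(f, x, update_freq=5, steps=20, reg=1e-6):
        """Toy Newton loop that refreshes the Hessian every `update_freq` steps."""
        H = None
        for step in range(steps):
            x = x.detach().requires_grad_(True)
            g = torch.autograd.grad(f(x), x)[0]
            if step % update_freq == 0:
                # recompute and Tikhonov-regularize the Hessian on "refresh" steps
                H = torch.autograd.functional.hessian(f, x.detach())
                H = H + reg * torch.eye(x.numel(), dtype=x.dtype)
            # reuse the cached H on the remaining steps
            x = x.detach() - torch.linalg.solve(H, g)
        return x

    # usage: minimize the quadratic f(x) = 0.5 * x @ A @ x - b @ x
    A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
    b = torch.tensor([1.0, 1.0])
    x = newton_like_loop(lambda x: 0.5 * x @ A @ x - b @ x, torch.zeros(2))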
@@ -156,4 +238,101 @@ class Newton(Module):
   if update is None: update = least_squares_solve(H, g)

   var.update = vec_to_tensors(update, params)
+
+ return var
+
+ class InverseFreeNewton(Module):
+ """Inverse-free newton's method
+
+ .. note::
+ In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating the hessian.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference
+ Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
+ """
+ def __init__(
+ self,
+ update_freq: int = 1,
+ hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+ vectorize: bool = True,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
+ super().__init__(defaults)
+
+ if inner is not None:
+ self.set_child('inner', inner)
+
+ @torch.no_grad
+ def step(self, var):
+ params = TensorList(var.params)
+ closure = var.closure
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+ settings = self.settings[params[0]]
+ hessian_method = settings['hessian_method']
+ vectorize = settings['vectorize']
+ update_freq = settings['update_freq']
+
+ step = self.global_state.get('step', 0)
+ self.global_state['step'] = step + 1
+
+ g_list = var.grad
+ Y = None
+ if step % update_freq == 0:
+ # ------------------------ calculate grad and hessian ------------------------ #
+ if hessian_method == 'autograd':
+ with torch.enable_grad():
+ loss = var.loss = var.loss_approx = closure(False)
+ g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+ g_list = [t[0] for t in g_list] # remove leading dim from loss
+ var.grad = g_list
+ H = hessian_list_to_mat(H_list)
+
+ elif hessian_method in ('func', 'autograd.functional'):
+ strat = 'forward-mode' if vectorize else 'reverse-mode'
+ with torch.enable_grad():
+ g_list = var.get_grad(retain_graph=True)
+ H = hessian_mat(partial(closure, backward=False), params,
+ method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+ else:
+ raise ValueError(hessian_method)
+
+ # inverse free part
+ if 'Y' not in self.global_state:
+ num = H.T
+ denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+ eps = torch.finfo(H.dtype).eps
+ Y = self.global_state['Y'] = num.div_(denom.clip(min=eps, max=1/eps))
+
+ else:
+ Y = self.global_state['Y']
+ I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+ I -= H @ Y
+ Y = self.global_state['Y'] = Y @ I
+
+ if Y is None:
+ Y = self.global_state["Y"]
+
+
+ # -------------------------------- inner step -------------------------------- #
+ update = var.get_update()
+ if 'inner' in self.children:
+ update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+
+ g = torch.cat([t.ravel() for t in update])
+
+
+ # ----------------------------------- solve ---------------------------------- #
+ var.update = vec_to_tensors(Y@g, params)
+
   return var
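The "inverse free part" above maintains an approximate inverse Y of the Hessian with the Newton-Schulz-style update Y ← Y(2I − HY), seeded with Hᵀ / (‖H‖₁‖H‖∞). A minimal standalone sketch of that iteration on a fixed symmetric positive-definite matrix (illustrative only, not the module's code path):

    import torch

    def approximate_inverse(H: torch.Tensor, iters: int = 20) -> torch.Tensor:
        # seed Y so that the spectral radius of (I - H @ Y) is below 1
        Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf')))
        I = torch.eye(H.size(0), dtype=H.dtype)
        for _ in range(iters):
            # Newton-Schulz update: Y <- Y (2I - H Y), converges quadratically to H^-1
            Y = Y @ (2 * I - H @ Y)
        return Y

    H = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
    Y = approximate_inverse(H)
    print(torch.allclose(Y, torch.linalg.inv(H), atol=1e-5))  # True

This is why the module never solves a linear system: the cached Y is refined once per refresh step and the update is simply Y @ g.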
@@ -1,26 +1,102 @@
- from collections.abc import Callable
  from typing import Literal, overload
- import warnings
  import torch

- from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
+ from ...utils import TensorList, as_tensorlist, NumberList
  from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward

  from ...core import Chainable, apply_transform, Module
- from ...utils.linalg.solve import cg
+ from ...utils.linalg.solve import cg, steihaug_toint_cg, minres

  class NewtonCG(Module):
+ """Newton's method with a matrix-free conjugate gradient or minimal-residual solver.
+
+ This optimizer implements Newton's method using a matrix-free conjugate
+ gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
+ forming the full Hessian matrix, it only requires Hessian-vector products
+ (HVPs). These can be calculated efficiently using automatic
+ differentiation or approximated using finite differences.
+
+ .. note::
+ In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. warning::
+ CG may fail if hessian is not positive-definite.
+
+ Args:
+ maxiter (int | None, optional):
+ Maximum number of iterations for the conjugate gradient solver.
+ By default, this is set to the number of dimensions in the
+ objective function, which is the theoretical upper bound for CG
+ convergence. Setting this to a smaller value (truncated Newton)
+ can still generate good search directions. Defaults to None.
+ tol (float, optional):
+ Relative tolerance for the conjugate gradient solver to determine
+ convergence. Defaults to 1e-4.
+ reg (float, optional):
+ Regularization parameter (damping) added to the Hessian diagonal.
+ This helps ensure the system is positive-definite. Defaults to 1e-8.
+ hvp_method (str, optional):
+ Determines how Hessian-vector products are evaluated.
+
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+ This requires creating a graph for the gradient.
+ - ``"forward"``: Use a forward finite difference formula to
+ approximate the HVP. This requires one extra gradient evaluation.
+ - ``"central"``: Use a central finite difference formula for a
+ more accurate HVP approximation. This requires two extra
+ gradient evaluations.
+ Defaults to "autograd".
+ h (float, optional):
+ The step size for finite differences if :code:`hvp_method` is
+ ``"forward"`` or ``"central"``. Defaults to 1e-3.
+ warm_start (bool, optional):
+ If ``True``, the conjugate gradient solver is initialized with the
+ solution from the previous optimization step. This can accelerate
+ convergence, especially in truncated Newton methods.
+ Defaults to False.
+ inner (Chainable | None, optional):
+ NewtonCG will attempt to apply preconditioning to the output of this module.
+
+ Examples:
+ Newton-CG with a backtracking line search:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCG(),
+ tz.m.Backtracking()
+ )
+
+ Truncated Newton method (useful for large-scale problems):
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCG(maxiter=10, warm_start=True),
+ tz.m.Backtracking()
+ )
+
+
+ """
  def __init__(
  self,
- maxiter=None,
- tol=1e-4,
+ maxiter: int | None = None,
+ tol: float = 1e-4,
  reg: float = 1e-8,
- hvp_method: Literal["forward", "central", "autograd"] = "forward",
- h=1e-3,
+ hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+ solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
+ h: float = 1e-3,
  warm_start=False,
  inner: Chainable | None = None,
  ):
- defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, warm_start=warm_start)
+ defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
  super().__init__(defaults,)

  if inner is not None:
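Since the docstring above distinguishes exact autograd HVPs from finite-difference approximations, here is a minimal sketch of both in plain PyTorch, written without torchzero's hvp helpers (function names are illustrative):

    import torch

    def hvp_autograd(f, x, v):
        # exact Hessian-vector product via double backward
        g = torch.autograd.grad(f(x), x, create_graph=True)[0]
        return torch.autograd.grad(g, x, grad_outputs=v)[0]

    def hvp_forward_fd(f, x, v, h=1e-4):
        # forward-difference approximation: (grad(x + h*v) - grad(x)) / h
        g0 = torch.autograd.grad(f(x), x)[0]
        xh = (x + h * v).detach().requires_grad_(True)
        gh = torch.autograd.grad(f(xh), xh)[0]
        return (gh - g0) / h

    x = torch.randn(3, requires_grad=True)
    v = torch.randn(3)
    f = lambda x: (x ** 3).sum()
    print(torch.allclose(hvp_autograd(f, x, v), hvp_forward_fd(f, x, v), atol=1e-2))

The autograd variant needs a gradient graph but is exact; the forward-difference variant trades accuracy for one extra gradient evaluation, which matches the trade-off described in the docstring.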
@@ -37,6 +113,7 @@ class NewtonCG(Module):
  reg = settings['reg']
  maxiter = settings['maxiter']
  hvp_method = settings['hvp_method']
+ solver = settings['solver'].lower().strip()
  h = settings['h']
  warm_start = settings['warm_start']

@@ -68,13 +145,25 @@ class NewtonCG(Module):
  # -------------------------------- inner step -------------------------------- #
  b = var.get_update()
  if 'inner' in self.children:
- b = as_tensorlist(apply_transform(self.children['inner'], b, params=params, grads=grad, var=var))
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+ b = as_tensorlist(b)

  # ---------------------------------- run cg ---------------------------------- #
  x0 = None
  if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway

- x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
+ if solver == 'cg':
+ x = cg(A_mm=H_mm, b=b, x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
+
+ elif solver == 'minres':
+ x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
+
+ elif solver == 'minres_npc':
+ x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+
+ else:
+ raise ValueError(f"Unknown solver {solver}")
+
  if warm_start:
  assert x0 is not None
  x0.copy_(x)
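The solver dispatch above passes a callable A_mm (the Hessian-vector product) to cg / minres rather than an explicit matrix. A minimal matrix-free conjugate gradient sketch illustrating that interface (generic PyTorch, not the signature of torchzero's cg):

    import torch

    def cg_matrix_free(A_mm, b, tol=1e-8, maxiter=None):
        """Solve A x = b where A is only available as a matvec callable A_mm."""
        x = torch.zeros_like(b)
        r = b.clone()          # residual b - A x (with x = 0)
        p = r.clone()          # search direction
        rs = r.dot(r)
        for _ in range(maxiter or b.numel()):
            Ap = A_mm(p)
            alpha = rs / p.dot(Ap)
            x = x + alpha * p
            r = r - alpha * Ap
            rs_new = r.dot(r)
            if rs_new.sqrt() < tol:
                break
            p = r + (rs_new / rs) * p
            rs = rs_new
        return x

    A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
    b = torch.tensor([1.0, 2.0])
    x = cg_matrix_free(lambda v: A @ v, b)
    print(torch.allclose(A @ x, b, atol=1e-5))  # True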
@@ -83,3 +172,203 @@ class NewtonCG(Module):
  return var


+ class TruncatedNewtonCG(Module):
+ """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
+
+ This optimizer implements Newton's method using a matrix-free conjugate
+ gradient (CG) solver to approximate the search direction. Instead of
+ forming the full Hessian matrix, it only requires Hessian-vector products
+ (HVPs). These can be calculated efficiently using automatic
+ differentiation or approximated using finite differences.
+
+ .. note::
+ In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. warning::
+ CG may fail if hessian is not positive-definite.
+
+ Args:
+ maxiter (int | None, optional):
+ Maximum number of iterations for the conjugate gradient solver.
+ By default, this is set to the number of dimensions in the
+ objective function, which is the theoretical upper bound for CG
+ convergence. Setting this to a smaller value (truncated Newton)
+ can still generate good search directions. Defaults to None.
+ eta (float, optional):
+ whenever actual to predicted loss reduction ratio is larger than this, a step is accepted.
+ nplus (float, optional):
+ trust region multiplier on successful steps.
+ nminus (float, optional):
+ trust region multiplier on unsuccessful steps.
+ init (float, optional): initial trust region.
+ tol (float, optional):
+ Relative tolerance for the conjugate gradient solver to determine
+ convergence. Defaults to 1e-4.
+ reg (float, optional):
+ Regularization parameter (damping) added to the Hessian diagonal.
+ This helps ensure the system is positive-definite. Defaults to 1e-8.
+ hvp_method (str, optional):
+ Determines how Hessian-vector products are evaluated.
+
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+ This requires creating a graph for the gradient.
+ - ``"forward"``: Use a forward finite difference formula to
+ approximate the HVP. This requires one extra gradient evaluation.
+ - ``"central"``: Use a central finite difference formula for a
+ more accurate HVP approximation. This requires two extra
+ gradient evaluations.
+ Defaults to "autograd".
+ h (float, optional):
+ The step size for finite differences if :code:`hvp_method` is
+ ``"forward"`` or ``"central"``. Defaults to 1e-3.
+ inner (Chainable | None, optional):
+ NewtonCG will attempt to apply preconditioning to the output of this module.
+
+ Examples:
+ Trust-region Newton-CG:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCGSteihaug(),
+ )
+
+ Reference:
+ Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
+ """
+ def __init__(
+ self,
+ maxiter: int | None = None,
+ eta: float= 1e-6,
+ nplus: float = 2,
+ nminus: float = 0.25,
+ init: float = 1,
+ tol: float = 1e-4,
+ reg: float = 1e-8,
+ hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+ solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
+ h: float = 1e-3,
+ max_attempts: int = 10,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, eta=eta, nplus=nplus, nminus=nminus, init=init, max_attempts=max_attempts, solver=solver)
+ super().__init__(defaults,)
+
+ if inner is not None:
+ self.set_child('inner', inner)
+
+ @torch.no_grad
+ def step(self, var):
+ params = TensorList(var.params)
+ closure = var.closure
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+ settings = self.settings[params[0]]
+ tol = settings['tol']
+ reg = settings['reg']
+ maxiter = settings['maxiter']
+ hvp_method = settings['hvp_method']
+ h = settings['h']
+ max_attempts = settings['max_attempts']
+ solver = settings['solver'].lower().strip()
+
+ eta = settings['eta']
+ nplus = settings['nplus']
+ nminus = settings['nminus']
+ init = settings['init']
+
+ # ---------------------- Hessian vector product function --------------------- #
+ if hvp_method == 'autograd':
+ grad = var.get_grad(create_graph=True)
+
+ def H_mm(x):
+ with torch.enable_grad():
+ return TensorList(hvp(params, grad, x, retain_graph=True))
+
+ else:
+
+ with torch.enable_grad():
+ grad = var.get_grad()
+
+ if hvp_method == 'forward':
+ def H_mm(x):
+ return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+
+ elif hvp_method == 'central':
+ def H_mm(x):
+ return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+
+ else:
+ raise ValueError(hvp_method)
+
+
+ # -------------------------------- inner step -------------------------------- #
+ b = var.get_update()
+ if 'inner' in self.children:
+ b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+ b = as_tensorlist(b)
+
+ # ---------------------------------- run cg ---------------------------------- #
+ success = False
+ x = None
+ while not success:
+ max_attempts -= 1
+ if max_attempts < 0: break
+
+ trust_region = self.global_state.get('trust_region', init)
+ if trust_region < 1e-8 or trust_region > 1e8:
+ trust_region = self.global_state['trust_region'] = init
+
+ if solver == 'cg':
+ x = steihaug_toint_cg(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg)
+
+ elif solver == 'minres':
+ x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
+
+ elif solver == 'minres_npc':
+ x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+
+ else:
+ raise ValueError(f"unknown solver {solver}")
+
+ # ------------------------------- trust region ------------------------------- #
+ Hx = H_mm(x)
+ pred_reduction = b.dot(x) - 0.5 * x.dot(Hx)
+
+ params -= x
+ loss_star = closure(False)
+ params += x
+ reduction = var.get_loss(False) - loss_star
+
+ rho = reduction / (pred_reduction.clip(min=1e-8))
+
+ # failed step
+ if rho < 0.25:
+ self.global_state['trust_region'] = trust_region * nminus
+
+ # very good step
+ elif rho > 0.75:
+ diff = trust_region - x.abs()
+ if (diff.global_min() / trust_region) > 1e-4: # hits boundary
+ self.global_state['trust_region'] = trust_region * nplus
+
+ # if the ratio is high enough then accept the proposed step
+ if rho > eta:
+ success = True
+
+ assert x is not None
+ if success:
+ var.update = x
+
+ else:
+ var.update = params.zeros_like()
+
+ return var
+
+
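The accept/shrink/grow logic above follows the standard trust-region recipe: compare the actual loss reduction against the reduction predicted by the quadratic model, shrink the radius when the ratio is poor, grow it when the ratio is high and the step is limited by the boundary, and accept the step once the ratio clears eta. A minimal sketch of that radius update in isolation (the function name and thresholds mirror the hunk above but are illustrative, not the torchzero API):

    def update_trust_region(radius, rho, step_hits_boundary,
                            nplus=2.0, nminus=0.25, eta=1e-6):
        """Return (new_radius, accept) given rho = actual / predicted reduction."""
        if rho < 0.25:
            # model was a poor predictor of the loss: shrink the region
            radius = radius * nminus
        elif rho > 0.75 and step_hits_boundary:
            # model was good and the step was limited by the boundary: expand
            radius = radius * nplus
        accept = rho > eta
        return radius, accept

    # example: a mediocre step inside the region keeps the radius and is accepted
    print(update_trust_region(1.0, rho=0.5, step_hits_boundary=False))  # (1.0, True)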