torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
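
A large part of this release is the relocation of ``torchzero/modules/optimizers/*`` to ``torchzero/modules/adaptive/*`` (the ``{optimizers → adaptive}`` entries above). For code that imported from the moved subpackage directly, the path changes as sketched below; the assumption here is that module names are unchanged and the top-level ``tz.m`` aliases keep working:

```python
# Hypothetical migration sketch based on the file moves listed above.
# Old path (0.3.11):
#   from torchzero.modules.optimizers import muon
# New path (0.3.14):
from torchzero.modules.adaptive import muon
```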
torchzero/modules/second_order/newton_cg.py
@@ -1,30 +1,24 @@
-from typing import Literal, overload
+import warnings
+import math
+from typing import Literal, cast
+from operator import itemgetter
 import torch
 
-from ...utils import TensorList, as_tensorlist, NumberList
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, as_tensorlist, tofloat
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
-from ...core import Chainable, apply_transform, Module
-from ...utils.linalg.solve import cg, steihaug_toint_cg, minres
+from ...utils.linalg.solve import cg, minres, find_within_trust_radius
+from ..trust_region.trust_region import default_radius
 
 
 class NewtonCG(Module):
     """Newton's method with a matrix-free conjugate gradient or minimal-residual solver.
 
-    This optimizer implements Newton's method using a matrix-free conjugate
-    gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
-    forming the full Hessian matrix, it only requires Hessian-vector products
-    (HVPs). These can be calculated efficiently using automatic
-    differentiation or approximated using finite differences.
-
-    .. note::
-        In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+    Notes:
+        * In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
-    .. note::
-        This module requires a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+        * This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
-    .. warning::
+    Warning:
         CG may fail if the Hessian is not positive-definite.
 
     Args:
@@ -63,45 +57,48 @@ class NewtonCG(Module):
             NewtonCG will attempt to apply preconditioning to the output of this module.
 
     Examples:
-        Newton-CG with a backtracking line search:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.NewtonCG(),
-                tz.m.Backtracking()
-            )
-
-        Truncated Newton method (useful for large-scale problems):
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.NewtonCG(maxiter=10, warm_start=True),
-                tz.m.Backtracking()
-            )
-
+        Newton-CG with a backtracking line search:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NewtonCG(),
+            tz.m.Backtracking()
+        )
+        ```
+
+        Truncated Newton method (useful for large-scale problems):
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NewtonCG(maxiter=10),
+            tz.m.Backtracking()
+        )
+        ```
 
     """
     def __init__(
         self,
         maxiter: int | None = None,
-        tol: float = 1e-4,
+        tol: float = 1e-8,
         reg: float = 1e-8,
         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
         solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
-        h: float = 1e-3,
+        h: float = 1e-3,  # tuned: 1e-4 or 1e-3
+        miniter: int = 1,
         warm_start=False,
         inner: Chainable | None = None,
     ):
-        defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
         super().__init__(defaults)
 
         if inner is not None:
            self.set_child('inner', inner)
 
+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
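
Both rewritten constructors replace the hand-written `dict(...)` of hyperparameters with `locals().copy()`. A standalone sketch of what that idiom does (the function name here is hypothetical, not part of torchzero):

```python
# At the very top of a function, locals() contains exactly the declared
# arguments, so copying it captures every hyperparameter at once; the
# non-hyperparameter entries ('self', 'inner') are then deleted.
def example_init(self=None, maxiter=None, tol=1e-8, reg=1e-8, inner=None):
    defaults = locals().copy()   # {'self': None, 'maxiter': None, 'tol': 1e-08, ...}
    del defaults['self'], defaults['inner']
    return defaults

assert example_init() == {'maxiter': None, 'tol': 1e-8, 'reg': 1e-8}
```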
@@ -117,11 +114,13 @@ class NewtonCG(Module):
         h = settings['h']
         warm_start = settings['warm_start']
 
+        self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
             grad = var.get_grad(create_graph=True)
 
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 with torch.enable_grad():
                     return TensorList(hvp(params, grad, x, retain_graph=True))
 
@@ -132,10 +131,12 @@ class NewtonCG(Module):
 
         if hvp_method == 'forward':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
 
         elif hvp_method == 'central':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
 
         else:
@@ -153,141 +154,154 @@ class NewtonCG(Module):
         if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList)  # initialized to 0, which is the default anyway
 
         if solver == 'cg':
-            x = cg(A_mm=H_mm, b=b, x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
+            d, _ = cg(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, miniter=self.defaults["miniter"], reg=reg)
 
         elif solver == 'minres':
-            x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
+            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
 
         elif solver == 'minres_npc':
-            x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
 
         else:
             raise ValueError(f"Unknown solver {solver}")
 
         if warm_start:
             assert x0 is not None
-            x0.copy_(x)
-
-        var.update = x
-        return var
+            x0.copy_(d)
 
+        var.update = d
 
-class TruncatedNewtonCG(Module):
-    """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
+        self._num_hvps += self._num_hvps_last_step
+        return var
 
-    This optimizer implements Newton's method using a matrix-free conjugate
-    gradient (CG) solver to approximate the search direction. Instead of
-    forming the full Hessian matrix, it only requires Hessian-vector products
-    (HVPs). These can be calculated efficiently using automatic
-    differentiation or approximated using finite differences.
 
-    .. note::
-        In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+class NewtonCGSteihaug(Module):
+    """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.
 
-    .. note::
-        This module requires a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+    Notes:
+        * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
-    .. warning::
-        CG may fail if the Hessian is not positive-definite.
+        * This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
-        maxiter (int | None, optional):
-            Maximum number of iterations for the conjugate gradient solver.
-            By default, this is set to the number of dimensions in the
-            objective function, which is the theoretical upper bound for CG
-            convergence. Setting this to a smaller value (truncated Newton)
-            can still generate good search directions. Defaults to None.
         eta (float, optional):
-            whenever actual to predicted loss reduction ratio is larger than this, a step is accepted.
-        nplus (float, optional):
-            trust region multiplier on successful steps.
-        nminus (float, optional):
-            trust region multiplier on unsuccessful steps.
-        init (float, optional): initial trust region.
+            if the ratio of actual to predicted reduction is larger than this, the step is accepted. Defaults to 0.0.
+        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
+        rho_good (float, optional):
+            if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by `nplus`.
+        rho_bad (float, optional):
+            if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by `nminus`.
+        init (float, optional): initial trust region value. Defaults to 1.
+        max_attempts (int, optional):
+            maximum number of trust radius reductions per step. A zero update vector is returned when
+            this limit is exceeded. Defaults to 100.
+        max_history (int, optional):
+            CG will store this many intermediate solutions, reusing them when the trust radius is reduced
+            instead of re-running CG. Each stored solution requires 2N memory. Defaults to 100.
+        boundary_tol (float | None, optional):
+            the trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
+            This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.
+
+        maxiter (int | None, optional):
+            maximum number of CG iterations per step. Each iteration requires one backward pass if `hvp_method="forward"`, two otherwise. Defaults to None.
+        miniter (int, optional):
+            minimal number of CG iterations. This prevents making no progress when the initial guess is already below tolerance. Defaults to 1.
         tol (float, optional):
-            Relative tolerance for the conjugate gradient solver to determine
-            convergence. Defaults to 1e-4.
-        reg (float, optional):
-            Regularization parameter (damping) added to the Hessian diagonal.
-            This helps ensure the system is positive-definite. Defaults to 1e-8.
+            terminates CG when the norm of the residual is less than this value. Defaults to 1e-8.
+        reg (float, optional): Hessian regularization. Defaults to 1e-8.
+        solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
+        adapt_tol (bool, optional):
+            if True, whenever the trust radius collapses to the smallest representable number,
+            the tolerance is multiplied by 0.1. Defaults to True.
+        npc_terminate (bool, optional):
+            whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.
+
         hvp_method (str, optional):
-            Determines how Hessian-vector products are evaluated.
+            either "forward" to use the forward formula, which requires one backward pass per Hvp, or "central" to use a more accurate central formula, which requires two backward passes. "forward" is usually accurate enough. Defaults to "central".
+        h (float, optional): finite difference step size. Defaults to 1e-3.
 
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        h (float, optional):
-            The step size for finite differences if :code:`hvp_method` is
-            ``"forward"`` or ``"central"``. Defaults to 1e-3.
         inner (Chainable | None, optional):
-            NewtonCG will attempt to apply preconditioning to the output of this module.
-
-    Examples:
-        Trust-region Newton-CG:
+            applies preconditioning to the output of this module. Defaults to None.
 
-        .. code-block:: python
+    ### Examples:
+    Trust-region Newton-CG:
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.NewtonCGSteihaug(),
-            )
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.NewtonCGSteihaug(),
+    )
+    ```
 
-    Reference:
+    ### Reference:
         Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
     """
     def __init__(
         self,
-        maxiter: int | None = None,
-        eta: float = 1e-6,
-        nplus: float = 2,
+        # trust region settings
+        eta: float = 0.0,
+        nplus: float = 3.5,
         nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
         init: float = 1,
-        tol: float = 1e-4,
+        max_attempts: int = 100,
+        max_history: int = 100,
+        boundary_tol: float = 1e-6,  # tuned
+
+        # cg settings
+        maxiter: int | None = None,
+        miniter: int = 1,
+        tol: float = 1e-8,
         reg: float = 1e-8,
-        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-        solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
-        h: float = 1e-3,
-        max_attempts: int = 10,
+        solver: Literal['cg', "minres"] = 'cg',
+        adapt_tol: bool = True,
+        npc_terminate: bool = False,
+
+        # hvp settings
+        hvp_method: Literal["forward", "central"] = "central",
+        h: float = 1e-3,  # tuned: 1e-4 or 1e-3
+
+        # inner
         inner: Chainable | None = None,
     ):
-        defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, eta=eta, nplus=nplus, nminus=nminus, init=init, max_attempts=max_attempts, solver=solver)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
         super().__init__(defaults)
 
         if inner is not None:
             self.set_child('inner', inner)
 
+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
-        settings = self.settings[params[0]]
-        tol = settings['tol']
-        reg = settings['reg']
-        maxiter = settings['maxiter']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-        max_attempts = settings['max_attempts']
-        solver = settings['solver'].lower().strip()
+        tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
+        solver = self.defaults['solver'].lower().strip()
+
+        (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
+         eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
+         miniter, max_history, adapt_tol) = itemgetter(
+            "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
+            "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
+            "miniter", "max_history", "adapt_tol",
+        )(self.defaults)
 
-        eta = settings['eta']
-        nplus = settings['nplus']
-        nminus = settings['nminus']
-        init = settings['init']
+        self._num_hvps_last_step = 0
 
         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
             grad = var.get_grad(create_graph=True)
 
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 with torch.enable_grad():
                     return TensorList(hvp(params, grad, x, retain_graph=True))
 
@@ -298,10 +312,12 @@ class TruncatedNewtonCG(Module):
 
         if hvp_method == 'forward':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
 
         elif hvp_method == 'central':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
 
         else:
@@ -314,61 +330,82 @@ class TruncatedNewtonCG(Module):
             b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
         b = as_tensorlist(b)
 
-        # ---------------------------------- run cg ---------------------------------- #
+        # ------------------------------- trust region ------------------------------- #
         success = False
-        x = None
+        d = None
+        x0 = [p.clone() for p in params]
+        solution = None
+
         while not success:
             max_attempts -= 1
             if max_attempts < 0: break
 
-            trust_region = self.global_state.get('trust_region', init)
-            if trust_region < 1e-8 or trust_region > 1e8:
-                trust_region = self.global_state['trust_region'] = init
-
-            if solver == 'cg':
-                x = steihaug_toint_cg(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg)
-
-            elif solver == 'minres':
-                x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
-
-            elif solver == 'minres_npc':
-                x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
-
-            else:
-                raise ValueError(f"unknown solver {solver}")
-
-            # ------------------------------- trust region ------------------------------- #
-            Hx = H_mm(x)
-            pred_reduction = b.dot(x) - 0.5 * x.dot(Hx)
-
-            params -= x
-            loss_star = closure(False)
-            params += x
-            reduction = var.get_loss(False) - loss_star
-
-            rho = reduction / (pred_reduction.clip(min=1e-8))
-
-            # failed step
-            if rho < 0.25:
-                self.global_state['trust_region'] = trust_region * nminus
-
-            # very good step
-            elif rho > 0.75:
-                diff = trust_region - x.abs()
-                if (diff.global_min() / trust_region) > 1e-4:  # hits boundary
-                    self.global_state['trust_region'] = trust_region * nplus
-
-            # if the ratio is high enough then accept the proposed step
-            if rho > eta:
-                success = True
+            trust_radius = self.global_state.get('trust_radius', init)
+
+            # -------------- make sure trust radius isn't too small or large ------------- #
+            finfo = torch.finfo(x0[0].dtype)
+            if trust_radius < finfo.tiny * 2:
+                trust_radius = self.global_state['trust_radius'] = init
+                if adapt_tol:
+                    self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
+
+            elif trust_radius > finfo.max / 2:
+                trust_radius = self.global_state['trust_radius'] = init
+
+            # ----------------------------------- solve ---------------------------------- #
+            d = None
+            if solution is not None and solution.history is not None:
+                d = find_within_trust_radius(solution.history, trust_radius)
+
+            if d is None:
+                if solver == 'cg':
+                    d, solution = cg(
+                        A_mm=H_mm,
+                        b=b,
+                        tol=tol,
+                        maxiter=maxiter,
+                        reg=reg,
+                        trust_radius=trust_radius,
+                        miniter=miniter,
+                        npc_terminate=npc_terminate,
+                        history_size=max_history,
+                    )
+
+                elif solver == 'minres':
+                    d = minres(A_mm=H_mm, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
+
+                else:
+                    raise ValueError(f"unknown solver {solver}")
+
+            # ---------------------------- update trust radius --------------------------- #
+            self.global_state["trust_radius"], success = default_radius(
+                params=params,
+                closure=closure,
+                f=tofloat(var.get_loss(False)),
+                g=b,
+                H=H_mm,
+                d=d,
+                trust_radius=trust_radius,
+                eta=eta,
+                nplus=nplus,
+                nminus=nminus,
+                rho_good=rho_good,
+                rho_bad=rho_bad,
+                boundary_tol=boundary_tol,
+
+                init=init,  # init isn't used because check_overflow=False
+                state=self.global_state,  # not used
+                settings=self.defaults,  # not used
+                check_overflow=False,  # this is checked manually to adapt the tolerance
+            )
 
-        assert x is not None
+        # --------------------------- assign new direction --------------------------- #
+        assert d is not None
         if success:
-            var.update = x
+            var.update = d
 
         else:
             var.update = params.zeros_like()
 
-        return var
-
-
+        self._num_hvps += self._num_hvps_last_step
+        return var
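
Both classes document that the closure passed to the optimizer step must accept a ``backward`` argument, so that the module can re-evaluate the loss with or without gradients (e.g. ``closure(False)`` for trial points). The exact contract is defined in torchzero's documentation; the sketch below only illustrates the typical shape, with a placeholder model, data, and loss:

```python
import torch

# Hypothetical setup; any model and loss work the same way.
model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

def closure(backward=True):
    # Re-evaluates the loss; gradients are only computed when requested.
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss
```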
torchzero/modules/smoothing/__init__.py
@@ -1,2 +1,2 @@
 from .laplacian import LaplacianSmoothing
-from .gaussian import GaussianHomotopy
+from .sampling import GradientSampling
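
For reference, the trust-radius parameters documented on ``NewtonCGSteihaug`` (``eta``, ``nplus``, ``nminus``, ``rho_good``, ``rho_bad``, ``boundary_tol``) describe the classic accept/shrink/grow scheme. The sketch below is only an illustration of that documented rule, not torchzero's actual ``default_radius`` implementation:

```python
def update_radius(actual_red, predicted_red, step_norm, radius,
                  eta=0.0, nplus=3.5, nminus=0.25,
                  rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    """Illustrative trust-radius update; not torchzero's default_radius."""
    rho = actual_red / max(predicted_red, 1e-32)  # model/objective agreement
    if rho < rho_bad:
        radius *= nminus                          # poor prediction: shrink
    elif rho > rho_good and step_norm >= (1 - boundary_tol) * radius:
        radius *= nplus                           # good prediction, step on boundary: grow
    return radius, rho > eta                      # new radius, whether step is accepted
```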