torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/higher_order/higher_order_newton.py
@@ -13,7 +13,7 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
 from ...utils.derivatives import (
-    hessian_list_to_mat,
+    flatten_jacobian,
     jacobian_wrt,
 )
 
@@ -70,57 +70,94 @@ def _proximal_poly_H(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
 def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
     derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
     x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
-    bounds = None
-    if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
 
-    # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
-    if bounds is None:
-        if len(derivatives) == 1: method = 'bfgs'
-        else: method = 'trust-exact'
+    # notes
+    # 1. since we have exact hessian we use trust methods
+
+    # 2. if len(derivatives) is 1, only gradient is available,
+    # thus use slsqp depending on whether trust region is enabled
+    # this is just so that I can test that trust region works
+    if trust_region is None:
+        if len(derivatives) == 1: raise RuntimeError("trust region must be enabled because 1st order has no minima")
+        method = 'trust-exact'
+        de_bounds = list(zip(x0 - 10, x0 + 10))
+        constraints = None
+
     else:
-        if len(derivatives) == 1: method = 'l-bfgs-b'
+        if len(derivatives) == 1: method = 'slsqp'
         else: method = 'trust-constr'
+        de_bounds = list(zip(x0 - trust_region, x0 + trust_region))
+
+        def l2_bound_f(x):
+            if x.ndim == 2: return np.sum((x - x0[:,None])**2, axis=0)[None,:] # DE passes (ndim, batch_size) and expects (M, S)
+            return np.sum((x - x0)**2, axis=0)
+
+        def l2_bound_g(x):
+            return 2 * (x - x0)
+
+        def l2_bound_h(x, v):
+            return v[0] * 2 * np.eye(x0.shape[0])
+
+        constraint = scipy.optimize.NonlinearConstraint(
+            fun=l2_bound_f,
+            lb=0,               # 0 <= ||x-x0||^2
+            ub=trust_region**2, # ||x-x0||^2 <= R^2
+            jac=l2_bound_g, # pyright:ignore[reportArgumentType]
+            hess=l2_bound_h,
+            keep_feasible=False
+        )
+        constraints = [constraint]
 
     x_init = x0.copy()
     v0 = _proximal_poly_v(x0, c, prox, x0, derivatives)
+
+    # ---------------------------------- run DE ---------------------------------- #
     if de_iters is not None and de_iters != 0:
         if de_iters == -1: de_iters = None # let scipy decide
+
+        # DE needs bounds so use linf ig
         res = scipy.optimize.differential_evolution(
             _proximal_poly_v,
-            bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
+            de_bounds,
             args=(c, prox, x0.copy(), derivatives),
             maxiter=de_iters,
             vectorized=True,
+            constraints = constraints,
+            updating='deferred',
         )
-        if res.fun < v0: x_init = res.x
-
-    res = scipy.optimize.minimize(
-        _proximal_poly_v,
-        x_init,
-        method=method,
-        args=(c, prox, x0.copy(), derivatives),
-        jac=_proximal_poly_g,
-        hess=_proximal_poly_H,
-        bounds=bounds
-    )
+        if res.fun < v0 and np.all(np.isfinite(res.x)): x_init = res.x
 
+    # ------------------------------- run minimize ------------------------------- #
+    try:
+        res = scipy.optimize.minimize(
+            _proximal_poly_v,
+            x_init,
+            method=method,
+            args=(c, prox, x0.copy(), derivatives),
+            jac=_proximal_poly_g,
+            hess=_proximal_poly_H,
+            constraints = constraints,
+        )
+    except ValueError:
+        return x, -float('inf')
     return torch.from_numpy(res.x).to(x), res.fun
 
 
 
 class HigherOrderNewton(Module):
-    """
-    A basic arbitrary order newton's method with optional trust region and proximal penalty.
-    It is recommended to enable at least one of trust region or proximal penalty.
+    """A basic arbitrary order newton's method with optional trust region and proximal penalty.
 
     This constructs an nth order taylor approximation via autograd and minimizes it with
-    scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
+    ``scipy.optimize.minimize`` trust region newton solvers with optional proximal penalty.
 
-    This uses n^order memory, where n is number of decision variables, and I am not aware
-    of any problems where this is more efficient than newton's method. It can minimize
-    rosenbrock in a single step, but that step probably takes more time than newton.
-    And there are way more efficient tensor methods out there but they tend to be
-    significantly more complex.
+    The hessian of the taylor approximation is easier to evaluate, and it can be evaluated in a batched mode,
+    so it can be more efficient in very specific instances.
+
+    Notes:
+        - In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating higher order derivatives. The closure must accept a ``backward`` argument (refer to documentation).
+        - This uses roughly O(N^order) memory and solving the subproblem is very expensive.
+        - "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
 
 
     Args:
@@ -136,7 +173,7 @@ class HigherOrderNewton(Module):
         increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
         decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
         trust_init (float | None, optional):
-            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on :code:`"proximal"`. Defaults to None.
+            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on ``"proximal"``. Defaults to None.
         trust_tol (float, optional):
             Maximum ratio of expected loss reduction to actual reduction for trust region increase.
             Should be 1 or higher. Defaults to 2.
@@ -149,38 +186,43 @@ class HigherOrderNewton(Module):
         self,
         order: int = 4,
         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-        increase: float = 1.5,
-        decrease: float = 0.75,
-        trust_init: float | None = None,
-        trust_tol: float = 2,
+        nplus: float = 3.5,
+        nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
+        init: float | None = None,
+        eta: float = 1e-6,
+        max_attempts = 10,
+        boundary_tol: float = 1e-2,
         de_iters: int | None = None,
         vectorize: bool = True,
     ):
-        if trust_init is None:
-            if trust_method == 'bounds': trust_init = 1
-            else: trust_init = 0.1
+        if init is None:
+            if trust_method == 'bounds': init = 1
+            else: init = 0.1
 
-        defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
         super().__init__(defaults)
 
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
         closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
+        if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
 
         settings = self.settings[params[0]]
         order = settings['order']
-        increase = settings['increase']
-        decrease = settings['decrease']
-        trust_tol = settings['trust_tol']
-        trust_init = settings['trust_init']
+        nplus = settings['nplus']
+        nminus = settings['nminus']
+        eta = settings['eta']
+        init = settings['init']
         trust_method = settings['trust_method']
         de_iters = settings['de_iters']
+        max_attempts = settings['max_attempts']
         vectorize = settings['vectorize']
-
-        trust_value = self.global_state.get('trust_value', trust_init)
-
+        boundary_tol = settings['boundary_tol']
+        rho_good = settings['rho_good']
+        rho_bad = settings['rho_bad']
 
         # ------------------------ calculate grad and hessian ------------------------ #
         with torch.enable_grad():
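
The constructor change renames `increase`/`decrease`/`trust_init`/`trust_tol` to `nplus`/`nminus`/`init` and adds explicit acceptance-ratio thresholds. A construction sketch under the new signature; `tz.m.HigherOrderNewton` as the public alias is an assumption here, following the `tz.m.*` convention in the gn.py examples further down:

```python
import torch
import torch.nn as nn
import torchzero as tz  # import alias as used in the gn.py docstring below

model = nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.HigherOrderNewton(  # assumed tz.m export of the module above
        order=4,
        trust_method='bounds',
        nplus=3.5, nminus=0.25,       # radius multipliers for very good / failed steps
        rho_good=0.99, rho_bad=1e-4,  # acceptance-ratio thresholds
        init=1.0,                     # initial trust radius (formerly trust_init)
        eta=1e-6,                     # minimum ratio at which a step is accepted
        max_attempts=10,              # retry budget of the new while-loop in step()
    ),
)

# the closure must accept a `backward` argument (see the Notes in the docstring)
def closure(backward=True):
    return model(torch.ones(4)).pow(2).sum()

opt.step(closure)
```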
@@ -200,57 +242,86 @@ class HigherOrderNewton(Module):
             T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
             with torch.no_grad() if is_last else nullcontext():
                 # the shape is (ndim, ) * order
-                T = hessian_list_to_mat(T_list).view(n, n, *T.shape[1:])
+                T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
                 derivatives.append(T)
 
         x0 = torch.cat([p.ravel() for p in params])
 
-        if trust_method is None: trust_method = 'none'
-        else: trust_method = trust_method.lower()
-
-        if trust_method == 'none':
-            trust_region = None
-            prox = 0
-
-        elif trust_method == 'bounds':
-            trust_region = trust_value
-            prox = 0
-
-        elif trust_method == 'proximal':
-            trust_region = None
-            prox = 1 / trust_value
-
+        success = False
+        x_star = None
+        while not success:
+            max_attempts -= 1
+            if max_attempts < 0: break
+
+            # load trust region value
+            trust_value = self.global_state.get('trust_region', init)
+
+            # make sure its not too small or too large
+            finfo = torch.finfo(x0.dtype)
+            if trust_value < finfo.tiny*2 or trust_value > finfo.max / (2*nplus):
+                trust_value = self.global_state['trust_region'] = settings['init']
+
+            # determine tr and prox values
+            if trust_method is None: trust_method = 'none'
+            else: trust_method = trust_method.lower()
+
+            if trust_method == 'none':
+                trust_region = None
+                prox = 0
+
+            elif trust_method == 'bounds':
+                trust_region = trust_value
+                prox = 0
+
+            elif trust_method == 'proximal':
+                trust_region = None
+                prox = 1 / trust_value
+
+            else:
+                raise ValueError(trust_method)
+
+            # minimize the model
+            x_star, expected_loss = _poly_minimize(
+                trust_region=trust_region,
+                prox=prox,
+                de_iters=de_iters,
+                c=loss.item(),
+                x=x0,
+                derivatives=derivatives,
+            )
+
+            # update trust region
+            if trust_method == 'none':
+                success = True
+            else:
+                pred_reduction = loss - expected_loss
+
+                vec_to_tensors_(x_star, params)
+                loss_star = closure(False)
+                vec_to_tensors_(x0, params)
+                reduction = loss - loss_star
+
+                rho = reduction / (max(pred_reduction, 1e-8))
+                # failed step
+                if rho < rho_bad:
+                    self.global_state['trust_region'] = trust_value * nminus
+
+                # very good step
+                elif rho > rho_good:
+                    step = (x_star - x0)
+                    magn = torch.linalg.vector_norm(step) # pylint:disable=not-callable
+                    if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
+                        # close to boundary
+                        self.global_state['trust_region'] = trust_value * nplus
+
+                # if the ratio is high enough then accept the proposed step
+                success = rho > eta
+
+        assert x_star is not None
+        if success:
+            difference = vec_to_tensors(x0 - x_star, params)
+            var.update = list(difference)
         else:
-            raise ValueError(trust_method)
-
-        x_star, expected_loss = _poly_minimize(
-            trust_region=trust_region,
-            prox=prox,
-            de_iters=de_iters,
-            c=loss.item(),
-            x=x0,
-            derivatives=derivatives,
-        )
-
-        # trust region
-        if trust_method != 'none':
-            expected_reduction = loss - expected_loss
-
-            vec_to_tensors_(x_star, params)
-            loss_star = closure(False)
-            vec_to_tensors_(x0, params)
-            reduction = loss - loss_star
-
-            # failed step
-            if reduction <= 0:
-                x_star = x0
-                self.global_state['trust_value'] = trust_value * decrease
-
-            # very good step
-            elif expected_reduction / reduction <= trust_tol:
-                self.global_state['trust_value'] = trust_value * increase
-
-            difference = vec_to_tensors(x0 - x_star, params)
-            var.update = list(difference)
+            var.update = params.zeros_like()
         return var
 
torchzero/modules/least_squares/__init__.py
@@ -0,0 +1 @@
+from .gn import SumOfSquares, GaussNewton
torchzero/modules/least_squares/gn.py
@@ -0,0 +1,161 @@
+import torch
+from ...core import Module
+
+from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+from ...utils import vec_to_tensors
+from ...utils.linalg import linear_operator
+class SumOfSquares(Module):
+    """Sets loss to be the sum of squares of values returned by the closure.
+
+    This is meant to be used to test least squares methods against ordinary minimization methods.
+
+    To use this, the closure should return a vector of values to minimize sum of squares of.
+    Please add the `backward` argument, it will always be False but it is required.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+
+        if closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False)
+                        loss = loss.pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False)
+                return loss.pow(2).sum()
+
+            var.closure = sos_closure
+
+        if var.loss is not None:
+            var.loss = var.loss.pow(2).sum()
+
+        if var.loss_approx is not None:
+            var.loss_approx = var.loss_approx.pow(2).sum()
+
+        return var
+
+
+class GaussNewton(Module):
+    """Gauss-newton method.
+
+    To use this, the closure should return a vector of values to minimize sum of squares of.
+    Please add the ``backward`` argument, it will always be False but it is required.
+    Gradients will be calculated via batched autograd within this module, you don't need to
+    implement the backward pass. Please see below for an example.
+
+    Note:
+        This method requires ``ndim^2`` memory, however, if it is used within ``tz.m.TrustCG`` trust region,
+        the memory requirement is ``ndim*m``, where ``m`` is number of values in the output.
+
+    Args:
+        reg (float, optional): regularization parameter. Defaults to 1e-8.
+        batched (bool, optional): whether to use vmapping. Defaults to True.
+
+    Examples:
+
+    minimizing the rosenbrock function:
+    ```python
+    def rosenbrock(X):
+        x1, x2 = X
+        return torch.stack([(1 - x1), 100 * (x2 - x1**2)])
+
+    X = torch.tensor([-1.1, 2.5], requires_grad=True)
+    opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())
+
+    # define the closure for line search
+    def closure(backward=True):
+        return rosenbrock(X)
+
+    # minimize
+    for iter in range(10):
+        loss = opt.step(closure)
+        print(f'{loss = }')
+    ```
+
+    training a neural network with a matrix-free GN trust region:
+    ```python
+    X = torch.randn(64, 20)
+    y = torch.randn(64, 10)
+
+    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.GaussNewton()),
+    )
+
+    def closure(backward=True):
+        y_hat = model(X) # (64, 10)
+        return (y_hat - y).pow(2).mean(0) # (10, )
+
+    for i in range(100):
+        losses = opt.step(closure)
+        if i % 10 == 0:
+            print(f'{losses.mean() = }')
+    ```
+    """
+    def __init__(self, reg:float = 1e-8, batched:bool=True, ):
+        super().__init__(defaults=dict(batched=batched, reg=reg))
+
+    @torch.no_grad
+    def update(self, var):
+        params = var.params
+        batched = self.defaults['batched']
+
+        closure = var.closure
+        assert closure is not None
+
+        # gauss newton direction
+        with torch.enable_grad():
+            f = var.get_loss(backward=False) # n_out
+            assert isinstance(f, torch.Tensor)
+            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+
+        var.loss = f.pow(2).sum()
+
+        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_out, ndim)
+        Gtf = G.T @ f.detach() # (ndim)
+        self.global_state["Gtf"] = Gtf
+        var.grad = vec_to_tensors(Gtf, var.params)
+
+        # set closure to calculate sum of squares for line searches etc
+        if var.closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False).pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False).pow(2).sum()
+                return loss
+
+            var.closure = sos_closure
+
+    @torch.no_grad
+    def apply(self, var):
+        reg = self.defaults['reg']
+
+        G = self.global_state['G']
+        Gtf = self.global_state['Gtf']
+
+        GtG = G.T @ G # (ndim, ndim)
+        if reg != 0:
+            GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))
+
+        v = torch.linalg.lstsq(GtG, Gtf).solution # pylint:disable=not-callable
+
+        var.update = vec_to_tensors(v, var.params)
+        return var
+
+    def get_H(self, var):
+        G = self.global_state['G']
+        return linear_operator.AtA(G)
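
`get_H` exposes the Gauss-Newton Hessian approximation GᵀG as a linear operator; judging by the name, `linear_operator.AtA` (added in torchzero/utils/linalg/linear_operator.py this release) applies it matrix-free. That is what makes the `ndim*m` memory claim in the docstring work inside `tz.m.TrustCG`: each matrix-vector product needs only the (m, ndim) Jacobian. A self-contained sketch of the idea (hypothetical class, not the library's implementation):

```python
import torch

class AtASketch:
    """Matrix-free v -> Gᵀ(Gv); stores only G, never the (ndim, ndim) GᵀG."""
    def __init__(self, G: torch.Tensor):  # G: (m, ndim) Jacobian of the residuals
        self.G = G

    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        return self.G.T @ (self.G @ v)  # two thin products, O(m * ndim) work

G = torch.randn(10, 1000)   # m = 10 residuals, ndim = 1000 parameters
H = AtASketch(G)
v = torch.randn(1000)
assert torch.allclose(H.matvec(v), (G.T @ G) @ v, rtol=1e-4, atol=1e-3)
```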
torchzero/modules/line_search/__init__.py
@@ -1,5 +1,5 @@
-from .line_search import LineSearch, GridLineSearch
-from .backtracking import backtracking_line_search, Backtracking, AdaptiveBacktracking
-from .strong_wolfe import StrongWolfe
+from .adaptive import AdaptiveTracking
+from .backtracking import AdaptiveBacktracking, Backtracking
+from .line_search import LineSearchBase
 from .scipy import ScipyMinimizeScalar
-from .trust_region import TrustRegion
+from .strong_wolfe import StrongWolfe
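
The `__init__` rewrite renames the `LineSearch` base class to `LineSearchBase`, drops `GridLineSearch` and `backtracking_line_search`, and moves trust-region logic out to the new `torchzero/modules/trust_region` subpackage listed above. A hedged migration sketch for downstream imports (the `TrustCG` export is assumed from the `tz.m.TrustCG` usage in gn.py; there is no drop-in `TrustRegion` replacement in this subpackage):

```python
# 0.3.10
# from torchzero.modules.line_search import LineSearch, GridLineSearch, TrustRegion

# 0.3.13
from torchzero.modules.line_search import LineSearchBase, AdaptiveTracking
from torchzero.modules.trust_region import TrustCG  # assumed export; see files-changed list
```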