torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/higher_order/higher_order_newton.py
@@ -13,7 +13,7 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
 from ...utils.derivatives import (
-    hessian_list_to_mat,
+    flatten_jacobian,
     jacobian_wrt,
 )
 
@@ -148,21 +148,16 @@ class HigherOrderNewton(Module):
     """A basic arbitrary order newton's method with optional trust region and proximal penalty.
 
     This constructs an nth order taylor approximation via autograd and minimizes it with
-    scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
+    ``scipy.optimize.minimize`` trust region newton solvers with optional proximal penalty.
 
-    .. note::
-        In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+    The Hessian of the Taylor approximation is easier to evaluate, and it can be evaluated in batched mode,
+    so it can be more efficient in very specific instances.
 
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating higher order derivatives.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. warning::
-        this uses roughly O(N^order) memory and solving the subproblem can be very expensive.
-
-    .. warning::
-        "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
+    Notes:
+        - In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating higher order derivatives. The closure must accept a ``backward`` argument (refer to documentation).
+        - This uses roughly O(N^order) memory, and solving the subproblem is very expensive.
+        - "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
 
     Args:
 
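For reference, the closure convention these notes rely on looks roughly like this (a minimal sketch; ``model``, ``loss_fn``, ``inputs``, ``targets``, and ``opt`` are placeholder names):

```python
# hypothetical closure for a torchzero optimizer: modules may call
# closure(False) to re-evaluate the loss without recomputing gradients
def closure(backward=True):
    loss = loss_fn(model(inputs), targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)
```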
@@ -178,7 +173,7 @@ class HigherOrderNewton(Module):
         increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
         decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
         trust_init (float | None, optional):
-            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on :code:`"proximal"`. Defaults to None.
+            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on ``"proximal"``. Defaults to None.
         trust_tol (float, optional):
             Maximum ratio of expected loss reduction to actual reduction for trust region increase.
             Should 1 or higer. Defaults to 2.
@@ -191,11 +186,14 @@ class HigherOrderNewton(Module):
         self,
         order: int = 4,
         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-        nplus: float = 2,
+        nplus: float = 3.5,
         nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
         init: float | None = None,
         eta: float = 1e-6,
         max_attempts = 10,
+        boundary_tol: float = 1e-2,
         de_iters: int | None = None,
         vectorize: bool = True,
     ):
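For illustration, constructing the module with the new signature might look like this (a hypothetical sketch; ``model`` is a placeholder):

```python
import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    tz.m.HigherOrderNewton(
        order=4,
        trust_method='bounds',
        nplus=3.5,          # trust region multiplier on very good steps
        nminus=0.25,        # trust region multiplier on failed steps
        rho_good=0.99,      # new: ratio above which a step counts as very good
        rho_bad=1e-4,       # new: ratio below which a step counts as failed
        boundary_tol=1e-2,  # new: relative slack for "step hits the boundary"
    ),
)
```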
@@ -203,7 +201,7 @@ class HigherOrderNewton(Module):
             if trust_method == 'bounds': init = 1
             else: init = 0.1
 
-        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts)
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
         super().__init__(defaults)
 
     @torch.no_grad
@@ -222,6 +220,9 @@ class HigherOrderNewton(Module):
         de_iters = settings['de_iters']
         max_attempts = settings['max_attempts']
         vectorize = settings['vectorize']
+        boundary_tol = settings['boundary_tol']
+        rho_good = settings['rho_good']
+        rho_bad = settings['rho_bad']
 
         # ------------------------ calculate grad and hessian ------------------------ #
         with torch.enable_grad():
@@ -241,7 +242,7 @@ class HigherOrderNewton(Module):
             T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
             with torch.no_grad() if is_last else nullcontext():
                 # the shape is (ndim, ) * order
-                T = hessian_list_to_mat(T_list).view(n, n, *T.shape[1:])
+                T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
                 derivatives.append(T)
 
         x0 = torch.cat([p.ravel() for p in params])
@@ -254,8 +255,13 @@ class HigherOrderNewton(Module):
 
         # load trust region value
         trust_value = self.global_state.get('trust_region', init)
-        if trust_value < 1e-8 or trust_value > 1e16: trust_value = self.global_state['trust_region'] = settings['init']
 
+        # make sure it's not too small or too large
+        finfo = torch.finfo(x0.dtype)
+        if trust_value < finfo.tiny*2 or trust_value > finfo.max / (2*nplus):
+            trust_value = self.global_state['trust_region'] = settings['init']
+
+        # determine tr and prox values
         if trust_method is None: trust_method = 'none'
         else: trust_method = trust_method.lower()
 
@@ -297,13 +303,15 @@ class HigherOrderNewton(Module):
 
             rho = reduction / (max(pred_reduction, 1e-8))
             # failed step
-            if rho < 0.25:
+            if rho < rho_bad:
                 self.global_state['trust_region'] = trust_value * nminus
 
             # very good step
-            elif rho > 0.75:
-                diff = trust_value - (x0 - x_star).abs_()
-                if (diff.amin() / trust_value) > 1e-4: # hits boundary
+            elif rho > rho_good:
+                step = (x_star - x0)
+                magn = torch.linalg.vector_norm(step) # pylint:disable=not-callable
+                if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
+                    # close to boundary
                     self.global_state['trust_region'] = trust_value * nplus
 
             # if the ratio is high enough then accept the proposed step
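Taken together, these changes implement a standard ratio-based trust region radius rule. A standalone sketch of the new logic (hypothetical helper, mirroring the diff above):

```python
def update_trust_radius(trust_value, rho, step_norm,
                        nplus=3.5, nminus=0.25,
                        rho_bad=1e-4, rho_good=0.99,
                        boundary_tol=1e-2, proximal=False):
    # failed step: the model's predicted reduction did not materialize
    if rho < rho_bad:
        return trust_value * nminus
    # very good step: expand only when the step pressed against the boundary
    # (proximal mode has no hard boundary, so it always expands)
    if rho > rho_good:
        if proximal or (trust_value - step_norm) / trust_value <= boundary_tol:
            return trust_value * nplus
    # otherwise leave the radius unchanged
    return trust_value
```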
torchzero/modules/least_squares/__init__.py (new file)
@@ -0,0 +1 @@
+from .gn import SumOfSquares, GaussNewton
torchzero/modules/least_squares/gn.py (new file)
@@ -0,0 +1,161 @@
+import torch
+from ...core import Module
+
+from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+from ...utils import vec_to_tensors
+from ...utils.linalg import linear_operator
+class SumOfSquares(Module):
+    """Sets the loss to be the sum of squares of the values returned by the closure.
+
+    This is meant to be used to test least squares methods against ordinary minimization methods.
+
+    To use this, the closure should return a vector of values whose sum of squares is to be minimized.
+    Please add the ``backward`` argument; it will always be False, but it is required.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+
+        if closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False)
+                        loss = loss.pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False)
+                return loss.pow(2).sum()
+
+            var.closure = sos_closure
+
+        if var.loss is not None:
+            var.loss = var.loss.pow(2).sum()
+
+        if var.loss_approx is not None:
+            var.loss_approx = var.loss_approx.pow(2).sum()
+
+        return var
+
+
+class GaussNewton(Module):
+    """Gauss-Newton method.
+
+    To use this, the closure should return a vector of values whose sum of squares is to be minimized.
+    Please add the ``backward`` argument; it will always be False, but it is required.
+    Gradients will be calculated via batched autograd within this module, so you don't need to
+    implement the backward pass. Please see below for an example.
+
+    Note:
+        This method requires ``ndim^2`` memory; however, if it is used within the ``tz.m.TrustCG`` trust region,
+        the memory requirement is ``ndim*m``, where ``m`` is the number of values in the output.
+
+    Args:
+        reg (float, optional): regularization parameter. Defaults to 1e-8.
+        batched (bool, optional): whether to use vmapping. Defaults to True.
+
+    Examples:
+
+    minimizing the Rosenbrock function:
+    ```python
+    def rosenbrock(X):
+        x1, x2 = X
+        return torch.stack([(1 - x1), 100 * (x2 - x1**2)])
+
+    X = torch.tensor([-1.1, 2.5], requires_grad=True)
+    opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())
+
+    # define the closure for the line search
+    def closure(backward=True):
+        return rosenbrock(X)
+
+    # minimize
+    for iter in range(10):
+        loss = opt.step(closure)
+        print(f'{loss = }')
+    ```
+
+    training a neural network with a matrix-free GN trust region:
+    ```python
+    X = torch.randn(64, 20)
+    y = torch.randn(64, 10)
+
+    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.GaussNewton()),
+    )
+
+    def closure(backward=True):
+        y_hat = model(X) # (64, 10)
+        return (y_hat - y).pow(2).mean(0) # (10, )
+
+    for i in range(100):
+        losses = opt.step(closure)
+        if i % 10 == 0:
+            print(f'{losses.mean() = }')
+    ```
+    """
+    def __init__(self, reg: float = 1e-8, batched: bool = True):
+        super().__init__(defaults=dict(batched=batched, reg=reg))
+
+    @torch.no_grad
+    def update(self, var):
+        params = var.params
+        batched = self.defaults['batched']
+
+        closure = var.closure
+        assert closure is not None
+
+        # Gauss-Newton direction
+        with torch.enable_grad():
+            f = var.get_loss(backward=False) # n_out
+            assert isinstance(f, torch.Tensor)
+            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+
+        var.loss = f.pow(2).sum()
+
+        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_out, ndim)
+        Gtf = G.T @ f.detach() # (ndim)
+        self.global_state["Gtf"] = Gtf
+        var.grad = vec_to_tensors(Gtf, var.params)
+
+        # set closure to calculate sum of squares for line searches etc.
+        if var.closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False).pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False).pow(2).sum()
+                return loss
+
+            var.closure = sos_closure
+
+    @torch.no_grad
+    def apply(self, var):
+        reg = self.defaults['reg']
+
+        G = self.global_state['G']
+        Gtf = self.global_state['Gtf']
+
+        GtG = G.T @ G # (ndim, ndim)
+        if reg != 0:
+            GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))
+
+        v = torch.linalg.lstsq(GtG, Gtf).solution # pylint:disable=not-callable
+
+        var.update = vec_to_tensors(v, var.params)
+        return var
+
+    def get_H(self, var):
+        G = self.global_state['G']
+        return linear_operator.AtA(G)
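For context, ``apply`` above solves the regularized Gauss-Newton normal equations: writing the residual vector as $f$ and its Jacobian as $G$, the step $v$ satisfies

$$(G^\top G + \lambda I)\, v = G^\top f, \qquad \lambda = \texttt{reg}.$$

This is also why ``get_H`` returns ``linear_operator.AtA(G)``: a matrix-free consumer such as ``tz.m.TrustCG`` only needs products of the form ``G.T @ (G @ x)``, which avoids materializing the ``ndim^2`` matrix ``G.T @ G`` and matches the memory note in the docstring.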
torchzero/modules/line_search/__init__.py
@@ -1,5 +1,5 @@
-from .adaptive import AdaptiveLineSearch
+from .adaptive import AdaptiveTracking
 from .backtracking import AdaptiveBacktracking, Backtracking
 from .line_search import LineSearchBase
 from .scipy import ScipyMinimizeScalar
-from .strong_wolfe import StrongWolfe
+from .strong_wolfe import StrongWolfe
torchzero/modules/line_search/_polyinterp.py (new file)
@@ -0,0 +1,289 @@
+import numpy as np
+import torch
+
+from .line_search import LineSearchBase
+
+
+# polynomial interpolation
+# this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+    """
+    Gives the minimizer and minimum of the interpolating polynomial over given points
+    based on function and derivative information. Defaults to bisection if no critical
+    points are valid.
+
+    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+    modifications.
+
+    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+    Last edited 12/6/18.
+
+    Inputs:
+        points (nparray): two-dimensional array with each point of form [x f g]
+        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+        plot (bool): plot interpolating polynomial
+
+    Outputs:
+        x_sol (float): minimizer of interpolating polynomial
+        F_min (float): minimum of interpolating polynomial
+
+    Note:
+        Set f or g to np.nan if they are unknown
+
+    """
+    no_points = points.shape[0]
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if x_min_bound is None:
+        x_min_bound = x_min
+    if x_max_bound is None:
+        x_max_bound = x_max
+
+    # explicit formula for quadratic interpolation
+    if no_points == 2 and order == 2 and plot is False:
+        # Solution to quadratic interpolation is given by:
+        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+        # x_min = x1 - g1/(2a)
+        # if x1 = 0, then is given by:
+        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+        if points[0, 0] == 0:
+            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+        else:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+            x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+    # explicit formula for cubic interpolation
+    elif no_points == 2 and order == 3 and plot is False:
+        # Solution to cubic interpolation is given by:
+        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+        # d2 = sqrt(d1^2 - g1*g2)
+        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+        value = d1 ** 2 - points[0, 2] * points[1, 2]
+        if value > 0:
+            d2 = np.sqrt(value)
+            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+        else:
+            x_sol = (x_max_bound + x_min_bound)/2
+
+    # solve linear system
+    else:
+        # define linear constraints
+        A = np.zeros((0, order + 1))
+        b = np.zeros((0, 1))
+
+        # add linear constraints on function values
+        for i in range(no_points):
+            if not np.isnan(points[i, 1]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order, -1, -1):
+                    constraint[0, order - j] = points[i, 0] ** j
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 1])
+
+        # add linear constraints on gradient values
+        for i in range(no_points):
+            if not np.isnan(points[i, 2]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order):
+                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 2])
+
+        # check if system is solvable
+        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+            x_sol = (x_min_bound + x_max_bound)/2
+            f_min = np.inf
+        else:
+            # solve linear system for interpolating polynomial
+            coeff = np.linalg.solve(A, b)
+
+            # compute critical points
+            dcoeff = np.zeros(order)
+            for i in range(len(coeff) - 1):
+                dcoeff[i] = coeff[i] * (order - i)
+
+            crit_pts = np.array([x_min_bound, x_max_bound])
+            crit_pts = np.append(crit_pts, points[:, 0])
+
+            if not np.isinf(dcoeff).any():
+                roots = np.roots(dcoeff)
+                crit_pts = np.append(crit_pts, roots)
+
+            # test critical points
+            f_min = np.inf
+            x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
+            for crit_pt in crit_pts:
+                if np.isreal(crit_pt):
+                    if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+                    if crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                        F_cp = np.polyval(coeff, crit_pt)
+                        if np.isreal(F_cp) and F_cp < f_min:
+                            x_sol = np.real(crit_pt)
+                            f_min = np.real(F_cp)
+
+            if(plot):
+                import matplotlib.pyplot as plt
+                plt.figure()
+                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                f = np.polyval(coeff, x)
+                plt.plot(x, f)
+                plt.plot(x_sol, f_min, 'x')
+
+    return x_sol
+
+
+# polynomial interpolation
+# this code is based on https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+# this one is modified: instead of clipping the solution by bounds, it tries a lower degree polynomial,
+# all the way down to bisection
+def _within_bounds(x, lb, ub):
+    if lb is not None and x < lb: return False
+    if ub is not None and x > ub: return False
+    return True
+
+def _quad_interp(points):
+    assert points.shape[0] == 2, points.shape
+    if points[0, 0] == 0:
+        denom = 2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0])
+        if abs(denom) > 1e-32:
+            return -points[0, 2] * points[1, 0] ** 2 / denom
+    else:
+        denom = (points[0, 0] - points[1, 0]) ** 2
+        if denom > 1e-32:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / denom
+            if a > 1e-32:
+                return points[0, 0] - points[0, 2]/(2*a)
+    return None
+
+def _cubic_interp(points, lb, ub):
+    assert points.shape[0] == 2, points.shape
+    denom = points[0, 0] - points[1, 0]
+    if abs(denom) > 1e-32:
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / denom)
+        value = d1 ** 2 - points[0, 2] * points[1, 2]
+        if value > 0:
+            d2 = np.sqrt(value)
+            denom = points[1, 2] - points[0, 2] + 2 * d2
+            if abs(denom) > 1e-32:
+                x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / denom)
+                if _within_bounds(x_sol, lb, ub): return x_sol
+
+    # try quadratic interpolation
+    x_sol = _quad_interp(points)
+    if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+    return None
+
+def _poly_interp(points, lb, ub):
+    no_points = points.shape[0]
+    assert no_points > 2, points.shape
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    # define linear constraints
+    A = np.zeros((0, order + 1))
+    b = np.zeros((0, 1))
+
+    # add linear constraints on function values
+    for i in range(no_points):
+        if not np.isnan(points[i, 1]):
+            constraint = np.zeros((1, order + 1))
+            for j in range(order, -1, -1):
+                constraint[0, order - j] = points[i, 0] ** j
+            A = np.append(A, constraint, 0)
+            b = np.append(b, points[i, 1])
+
+    # add linear constraints on gradient values
+    for i in range(no_points):
+        if not np.isnan(points[i, 2]):
+            constraint = np.zeros((1, order + 1))
+            for j in range(order):
+                constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+            A = np.append(A, constraint, 0)
+            b = np.append(b, points[i, 2])
+
+    # check if system is solvable
+    if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+        return None
+
+    # solve linear system for interpolating polynomial
+    coeff = np.linalg.solve(A, b)
+
+    # compute critical points
+    dcoeff = np.zeros(order)
+    for i in range(len(coeff) - 1):
+        dcoeff[i] = coeff[i] * (order - i)
+
+    lower = np.min(points[:, 0]) if lb is None else lb
+    upper = np.max(points[:, 0]) if ub is None else ub
+
+    crit_pts = np.array([lower, upper])
+    crit_pts = np.append(crit_pts, points[:, 0])
+
+    if not np.isinf(dcoeff).any():
+        roots = np.roots(dcoeff)
+        crit_pts = np.append(crit_pts, roots)
+
+    # test critical points
+    f_min = np.inf
+    x_sol = None
+    for crit_pt in crit_pts:
+        if np.isreal(crit_pt):
+            if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+            if _within_bounds(crit_pt, lb, ub):
+                F_cp = np.polyval(coeff, crit_pt)
+                if np.isreal(F_cp) and F_cp < f_min:
+                    x_sol = np.real(crit_pt)
+                    f_min = np.real(F_cp)
+
+    return x_sol
+
+def polyinterp2(points, lb, ub, unbounded: bool = False):
+    no_points = points.shape[0]
+    if no_points <= 1:
+        return (lb + ub)/2
+
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if not unbounded:
+        if lb is None:
+            lb = x_min
+        if ub is None:
+            ub = x_max
+
+    if no_points == 2 and order == 2:
+        x_sol = _quad_interp(points)
+        if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+        return (lb + ub)/2
+
+    if no_points == 2 and order == 3:
+        x_sol = _cubic_interp(points, lb, ub) # includes fallback on _quad_interp
+        if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+        return (lb + ub)/2
+
+    if no_points <= 2: # order < 2
+        return (lb + ub)/2
+
+    if no_points == 3:
+        for p in (points[:2], points[1:], points[::2]):
+            x_sol = _cubic_interp(p, lb, ub)
+            if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+    x_sol = _poly_interp(points, lb, ub)
+    if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+    return polyinterp2(points[1:], lb, ub)
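A sketch of how ``polyinterp2`` might be driven by a line search (hypothetical values, assuming the module layout above; each row of ``points`` is ``[x, f, g]``, with ``np.nan`` marking unknown entries):

```python
import numpy as np
from torchzero.modules.line_search._polyinterp import polyinterp2

# two step sizes with their function values and directional derivatives
points = np.array([
    [0.0, 1.00, -2.0],  # at step 0: f = 1.0, slope = -2.0
    [1.0, 0.80,  0.5],  # at the trial step
])

# minimizer of the interpolating cubic restricted to [0, 1];
# falls back to quadratic interpolation and finally bisection
t = polyinterp2(points, lb=0.0, ub=1.0)
print(t)  # ~0.47 for these values
```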