torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/higher_order/higher_order_newton.py
@@ -70,57 +70,99 @@ def _proximal_poly_H(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
 def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
     derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
     x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
-    bounds = None
-    if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
 
-    # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
-    if bounds is None:
-        if len(derivatives) == 1: method = 'bfgs'
-        else: method = 'trust-exact'
+    # notes
+    # 1. since we have exact hessian we use trust methods
+
+    # 2. if len(derivatives) is 1, only gradient is available,
+    # thus use slsqp depending on whether trust region is enabled
+    # this is just so that I can test that trust region works
+    if trust_region is None:
+        if len(derivatives) == 1: raise RuntimeError("trust region must be enabled because 1st order has no minima")
+        method = 'trust-exact'
+        de_bounds = list(zip(x0 - 10, x0 + 10))
+        constraints = None
 
     else:
-        if len(derivatives) == 1: method = 'l-bfgs-b'
+        if len(derivatives) == 1: method = 'slsqp'
         else: method = 'trust-constr'
+        de_bounds = list(zip(x0 - trust_region, x0 + trust_region))
+
+        def l2_bound_f(x):
+            if x.ndim == 2: return np.sum((x - x0[:,None])**2, axis=0)[None,:] # DE passes (ndim, batch_size) and expects (M, S)
+            return np.sum((x - x0)**2, axis=0)
+
+        def l2_bound_g(x):
+            return 2 * (x - x0)
+
+        def l2_bound_h(x, v):
+            return v[0] * 2 * np.eye(x0.shape[0])
+
+        constraint = scipy.optimize.NonlinearConstraint(
+            fun=l2_bound_f,
+            lb=0, # 0 <= ||x-x0||^2
+            ub=trust_region**2, # ||x-x0||^2 <= R^2
+            jac=l2_bound_g, # pyright:ignore[reportArgumentType]
+            hess=l2_bound_h,
+            keep_feasible=False
+        )
+        constraints = [constraint]
 
     x_init = x0.copy()
     v0 = _proximal_poly_v(x0, c, prox, x0, derivatives)
+
+    # ---------------------------------- run DE ---------------------------------- #
     if de_iters is not None and de_iters != 0:
         if de_iters == -1: de_iters = None # let scipy decide
+
+        # DE needs bounds so use linf ig
         res = scipy.optimize.differential_evolution(
             _proximal_poly_v,
-            bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
+            de_bounds,
             args=(c, prox, x0.copy(), derivatives),
             maxiter=de_iters,
             vectorized=True,
+            constraints = constraints,
+            updating='deferred',
         )
-        if res.fun < v0: x_init = res.x
-
-    res = scipy.optimize.minimize(
-        _proximal_poly_v,
-        x_init,
-        method=method,
-        args=(c, prox, x0.copy(), derivatives),
-        jac=_proximal_poly_g,
-        hess=_proximal_poly_H,
-        bounds=bounds
-    )
+        if res.fun < v0 and np.all(np.isfinite(res.x)): x_init = res.x
 
+    # ------------------------------- run minimize ------------------------------- #
+    try:
+        res = scipy.optimize.minimize(
+            _proximal_poly_v,
+            x_init,
+            method=method,
+            args=(c, prox, x0.copy(), derivatives),
+            jac=_proximal_poly_g,
+            hess=_proximal_poly_H,
+            constraints = constraints,
+        )
+    except ValueError:
+        return x, -float('inf')
     return torch.from_numpy(res.x).to(x), res.fun
 
 
 
 class HigherOrderNewton(Module):
-    """
-    A basic arbitrary order newton's method with optional trust region and proximal penalty.
-    It is recommended to enable at least one of trust region or proximal penalty.
+    """A basic arbitrary order newton's method with optional trust region and proximal penalty.
 
     This constructs an nth order taylor approximation via autograd and minimizes it with
     scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
 
-    This uses n^order memory, where n is number of decision variables, and I am not aware
-    of any problems where this is more efficient than newton's method. It can minimize
-    rosenbrock in a single step, but that step probably takes more time than newton.
-    And there are way more efficient tensor methods out there but they tend to be
-    significantly more complex.
+    .. note::
+        In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating higher order derivatives.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^order) memory and solving the subproblem can be very expensive.
+
+    .. warning::
+        "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
 
     Args:
 
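The change above replaces the old box bounds with an L2-ball trust region expressed as a scipy NonlinearConstraint and solved with trust-constr. A minimal standalone sketch of the same idea on a toy objective (the names f, ball and radius are illustrative and not part of torchzero):

    import numpy as np
    import scipy.optimize

    x0 = np.zeros(2)      # trust region center (the expansion point)
    radius = 0.5          # trust region radius

    def f(x):             # toy objective standing in for the polynomial model
        return (x[0] - 1.0) ** 2 + (x[1] + 2.0) ** 2

    # same form as l2_bound_f / l2_bound_g / l2_bound_h above: 0 <= ||x - x0||^2 <= radius^2
    ball = scipy.optimize.NonlinearConstraint(
        fun=lambda x: np.sum((x - x0) ** 2),
        lb=0,
        ub=radius ** 2,
        jac=lambda x: 2 * (x - x0),
        hess=lambda x, v: v[0] * 2 * np.eye(x0.shape[0]),
    )

    res = scipy.optimize.minimize(f, x0, method="trust-constr", constraints=[ball])
    print(res.x)          # minimizer of f restricted to the L2 ball around x0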
@@ -149,39 +191,38 @@ class HigherOrderNewton(Module):
         self,
         order: int = 4,
         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-        increase: float = 1.5,
-        decrease: float = 0.75,
-        trust_init: float | None = None,
-        trust_tol: float = 2,
+        nplus: float = 2,
+        nminus: float = 0.25,
+        init: float | None = None,
+        eta: float = 1e-6,
+        max_attempts = 10,
         de_iters: int | None = None,
         vectorize: bool = True,
     ):
-        if trust_init is None:
-            if trust_method == 'bounds': trust_init = 1
-            else: trust_init = 0.1
+        if init is None:
+            if trust_method == 'bounds': init = 1
+            else: init = 0.1
 
-        defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts)
         super().__init__(defaults)
 
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
         closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
+        if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
 
         settings = self.settings[params[0]]
         order = settings['order']
-        increase = settings['increase']
-        decrease = settings['decrease']
-        trust_tol = settings['trust_tol']
-        trust_init = settings['trust_init']
+        nplus = settings['nplus']
+        nminus = settings['nminus']
+        eta = settings['eta']
+        init = settings['init']
         trust_method = settings['trust_method']
         de_iters = settings['de_iters']
+        max_attempts = settings['max_attempts']
         vectorize = settings['vectorize']
 
-        trust_value = self.global_state.get('trust_value', trust_init)
-
-
         # ------------------------ calculate grad and hessian ------------------------ #
         with torch.enable_grad():
             loss = var.loss = var.loss_approx = closure(False)
@@ -205,52 +246,74 @@ class HigherOrderNewton(Module):
 
         x0 = torch.cat([p.ravel() for p in params])
 
-        if trust_method is None: trust_method = 'none'
-        else: trust_method = trust_method.lower()
-
-        if trust_method == 'none':
-            trust_region = None
-            prox = 0
-
-        elif trust_method == 'bounds':
-            trust_region = trust_value
-            prox = 0
-
-        elif trust_method == 'proximal':
-            trust_region = None
-            prox = 1 / trust_value
-
+        success = False
+        x_star = None
+        while not success:
+            max_attempts -= 1
+            if max_attempts < 0: break
+
+            # load trust region value
+            trust_value = self.global_state.get('trust_region', init)
+            if trust_value < 1e-8 or trust_value > 1e16: trust_value = self.global_state['trust_region'] = settings['init']
+
+            if trust_method is None: trust_method = 'none'
+            else: trust_method = trust_method.lower()
+
+            if trust_method == 'none':
+                trust_region = None
+                prox = 0
+
+            elif trust_method == 'bounds':
+                trust_region = trust_value
+                prox = 0
+
+            elif trust_method == 'proximal':
+                trust_region = None
+                prox = 1 / trust_value
+
+            else:
+                raise ValueError(trust_method)
+
+            # minimize the model
+            x_star, expected_loss = _poly_minimize(
+                trust_region=trust_region,
+                prox=prox,
+                de_iters=de_iters,
+                c=loss.item(),
+                x=x0,
+                derivatives=derivatives,
+            )
+
+            # update trust region
+            if trust_method == 'none':
+                success = True
+            else:
+                pred_reduction = loss - expected_loss
+
+                vec_to_tensors_(x_star, params)
+                loss_star = closure(False)
+                vec_to_tensors_(x0, params)
+                reduction = loss - loss_star
+
+                rho = reduction / (max(pred_reduction, 1e-8))
+                # failed step
+                if rho < 0.25:
+                    self.global_state['trust_region'] = trust_value * nminus
+
+                # very good step
+                elif rho > 0.75:
+                    diff = trust_value - (x0 - x_star).abs_()
+                    if (diff.amin() / trust_value) > 1e-4: # hits boundary
+                        self.global_state['trust_region'] = trust_value * nplus
+
+                # if the ratio is high enough then accept the proposed step
+                success = rho > eta
+
+        assert x_star is not None
+        if success:
+            difference = vec_to_tensors(x0 - x_star, params)
+            var.update = list(difference)
         else:
-            raise ValueError(trust_method)
-
-        x_star, expected_loss = _poly_minimize(
-            trust_region=trust_region,
-            prox=prox,
-            de_iters=de_iters,
-            c=loss.item(),
-            x=x0,
-            derivatives=derivatives,
-        )
-
-        # trust region
-        if trust_method != 'none':
-            expected_reduction = loss - expected_loss
-
-            vec_to_tensors_(x_star, params)
-            loss_star = closure(False)
-            vec_to_tensors_(x0, params)
-            reduction = loss - loss_star
-
-            # failed step
-            if reduction <= 0:
-                x_star = x0
-                self.global_state['trust_value'] = trust_value * decrease
-
-            # very good step
-            elif expected_reduction / reduction <= trust_tol:
-                self.global_state['trust_value'] = trust_value * increase
-
-            difference = vec_to_tensors(x0 - x_star, params)
-            var.update = list(difference)
+            var.update = params.zeros_like()
         return var
 
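For orientation, a hypothetical usage sketch of the rewritten module, assuming it is exposed as tz.m.HigherOrderNewton alongside the other modules and used through tz.Modular as in the Backtracking examples further below; per the docstring above, the closure must accept a backward flag:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)

    # hypothetical: assumes the module is exported as tz.m.HigherOrderNewton
    opt = tz.Modular(
        model.parameters(),
        tz.m.HigherOrderNewton(order=3, trust_method="bounds", nplus=2, nminus=0.25),
    )

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            model.zero_grad()
            loss.backward()
        return loss

    for _ in range(10):
        opt.step(closure)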
torchzero/modules/line_search/__init__.py
@@ -1,5 +1,5 @@
-from .line_search import LineSearch, GridLineSearch
-from .backtracking import backtracking_line_search, Backtracking, AdaptiveBacktracking
-from .strong_wolfe import StrongWolfe
+from .adaptive import AdaptiveLineSearch
+from .backtracking import AdaptiveBacktracking, Backtracking
+from .line_search import LineSearchBase
 from .scipy import ScipyMinimizeScalar
-from .trust_region import TrustRegion
+from .strong_wolfe import StrongWolfe
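In practical terms the base class is renamed and the subpackage exports change; a hedged migration sketch based only on the import lines above:

    # torchzero 0.3.10
    from torchzero.modules.line_search import LineSearch, GridLineSearch, TrustRegion

    # torchzero 0.3.11: LineSearch becomes LineSearchBase, AdaptiveLineSearch is new;
    # GridLineSearch and TrustRegion are no longer exported from this subpackage
    from torchzero.modules.line_search import LineSearchBase, AdaptiveLineSearch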
torchzero/modules/line_search/adaptive.py (new file)
@@ -0,0 +1,99 @@
+import math
+from collections.abc import Callable
+from operator import itemgetter
+
+import torch
+
+from .line_search import LineSearchBase
+
+
+
+def adaptive_tracking(
+    f,
+    x_0,
+    maxiter: int,
+    nplus: float = 2,
+    nminus: float = 0.5,
+):
+    f_0 = f(0)
+
+    t = x_0
+    f_t = f(t)
+
+    # backtrack
+    if f_t > f_0:
+        while f_t > f_0:
+            maxiter -= 1
+            if maxiter < 0: return 0, f_0
+            t = t*nminus
+            f_t = f(t)
+        return t, f_t
+
+    # forwardtrack
+    f_prev = f_t
+    t *= nplus
+    f_t = f(t)
+    if f_prev < f_t: return t / nplus, f_prev
+    while f_prev >= f_t:
+        maxiter -= 1
+        if maxiter < 0: return t, f_t
+        f_prev = f_t
+        t *= nplus
+        f_t = f(t)
+    return t / nplus, f_prev
+
+class AdaptiveLineSearch(LineSearchBase):
+    """Adaptive line search, similar to backtracking but also has forward tracking mode.
+    Currently doesn't check for weak curvature condition.
+
+    Args:
+        init (float, optional): initial step size. Defaults to 1.0.
+        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        adaptive (bool, optional):
+            when enabled, if line search failed, beta size is reduced.
+            Otherwise it is reset to initial value. Defaults to True.
+    """
+    def __init__(
+        self,
+        init: float = 1.0,
+        nplus: float = 2,
+        nminus: float = 0.5,
+        maxiter: int = 10,
+        adaptive=True,
+    ):
+        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive,)
+        super().__init__(defaults=defaults)
+        self.global_state['beta_scale'] = 1.0
+
+    def reset(self):
+        super().reset()
+        self.global_state['beta_scale'] = 1.0
+
+    @torch.no_grad
+    def search(self, update, var):
+        init, nplus, nminus, maxiter, adaptive = itemgetter(
+            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.settings[var.params[0]])
+
+        objective = self.make_objective(var=var)
+
+        # # directional derivative
+        # d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+
+        # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
+        beta_scale = self.global_state.get('beta_scale', 1)
+        x_prev = self.global_state.get('prev_x', 1)
+
+        if adaptive: nminus = nminus * beta_scale
+
+
+        step_size, f = adaptive_tracking(objective, x_prev, maxiter, nplus=nplus, nminus=nminus)
+
+        # found an alpha that reduces loss
+        if step_size != 0:
+            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            return step_size
+
+        # on fail reduce beta scale value
+        self.global_state['beta_scale'] /= 1.5
+        return 0
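The adaptive_tracking helper first backtracks (multiplying the step by nminus) when the initial step increases the loss, and otherwise forward-tracks (multiplying by nplus) until the loss stops improving. A small self-contained check on a 1D function, assuming the helper is importable from the new module shown above (illustrative only):

    from torchzero.modules.line_search.adaptive import adaptive_tracking

    def f(t):                 # loss along the search direction, minimizer at t = 0.3
        return (t - 0.3) ** 2

    # initial step too large: backtracks, halving t until the loss beats f(0)
    print(adaptive_tracking(f, 2.0, maxiter=10, nplus=2, nminus=0.5))   # -> (0.5, 0.04)

    # initial step already decreases the loss: forward-tracks, doubling t while it keeps improving
    print(adaptive_tracking(f, 0.05, maxiter=10, nplus=2, nminus=0.5))  # -> (0.2, ~0.01)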
torchzero/modules/line_search/backtracking.py
@@ -4,7 +4,7 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import LineSearch
+from .line_search import LineSearchBase
 
 
 def backtracking_line_search(
@@ -19,12 +19,12 @@ def backtracking_line_search(
    """
 
    Args:
-        objective_fn: evaluates step size along some descent direction.
-        dir_derivative: directional derivative along the descent direction.
-        alpha_init: initial step size.
+        f: evaluates step size along some descent direction.
+        g_0: directional derivative along the descent direction.
+        init: initial step size.
        beta: The factor by which to decrease alpha in each iteration
        c: The constant for the Armijo sufficient decrease condition
-        max_iter: Maximum number of backtracking iterations (default: 10).
+        maxiter: Maximum number of backtracking iterations (default: 10).
 
    Returns:
        step size
@@ -32,11 +32,15 @@ def backtracking_line_search(
 
    a = init
    f_x = f(0)
+    f_prev = None
 
    for iteration in range(maxiter):
        f_a = f(a)
 
-        if f_a <= f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
+        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_x): return a / beta
+        f_prev = f_a
+
+        if f_a < f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
            # found an acceptable alpha
            return a
 
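The acceptance test in the loop above is the Armijo sufficient-decrease condition, f(a) < f(0) + c*a*min(g_0, 0); the new f_prev bookkeeping additionally returns a/beta as soon as the loss starts rising again after having dropped below f(0). A short illustrative check of the Armijo rule itself (not torchzero API):

    # illustrative only: the sufficient-decrease test used above, with c = 1e-4
    def armijo_accepts(f_a, f_x, a, g_0, c=1e-4):
        # accept step size a only if the loss dropped at least proportionally to a * |g_0|
        return f_a < f_x + c * a * min(g_0, 0)

    f = lambda a: (a - 1.0) ** 2          # loss along the direction; f(0) = 1, slope at 0 is -2
    print(armijo_accepts(f(1.0), f(0), 1.0, g_0=-2.0))  # True: full step gives a large decrease
    print(armijo_accepts(f(2.5), f(0), 2.5, g_0=-2.0))  # False: overshoot, loss went up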
@@ -59,7 +63,7 @@ def backtracking_line_search(
 
    return None
 
-class Backtracking(LineSearch):
+class Backtracking(LineSearchBase):
    """Backtracking line search satisfying the Armijo condition.
 
    Args:
@@ -68,9 +72,30 @@ class Backtracking(LineSearch):
        c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
        adaptive (bool, optional):
-            when enabled, if line search failed, initial step size is reduced.
+            when enabled, if line search failed, beta is reduced.
            Otherwise it is reset to initial value. Defaults to True.
        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+
+    Examples:
+        Gradient descent with backtracking line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Backtracking()
+            )
+
+        LBFGS with backtracking line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(),
+                tz.m.Backtracking()
+            )
+
    """
    def __init__(
        self,
@@ -117,7 +142,7 @@ class Backtracking(LineSearch):
 def _lerp(start,end,weight):
    return start + weight * (end - start)
 
-class AdaptiveBacktracking(LineSearch):
+class AdaptiveBacktracking(LineSearchBase):
    """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
    such that optimal step size in the procedure would be found on the second line search iteration.
 
torchzero/modules/line_search/line_search.py
@@ -15,8 +15,9 @@ from ...utils import tofloat
 class MaxLineSearchItersReached(Exception): pass
 
 
-class LineSearch(Module, ABC):
+class LineSearchBase(Module, ABC):
    """Base class for line searches.
+
    This is an abstract class, to use it, subclass it and override `search`.
 
    Args:
@@ -26,6 +27,62 @@ class LineSearch(Module, ABC):
            the objective this many times, and step size with the lowest loss value will be used.
            This is useful when passing `make_objective` to an external library which
            doesn't have a maxiter option. Defaults to None.
+
+    Other useful methods:
+        * `evaluate_step_size` - returns loss with a given scalar step size
+        * `evaluate_step_size_loss_and_derivative` - returns loss and directional derivative with a given scalar step size
+        * `make_objective` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+        * `make_objective_with_derivative` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
+
+    Examples:
+        #### Basic line search
+
+        This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+
+        .. code-block:: python
+
+            class GridLineSearch(LineSearch):
+                def __init__(self, start, end, num):
+                    defaults = dict(start=start,end=end,num=num)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def search(self, update, var):
+                    settings = self.settings[var.params[0]]
+                    start = settings["start"]
+                    end = settings["end"]
+                    num = settings["num"]
+
+                    lowest_loss = float("inf")
+                    best_step_size = 0
+
+                    for step_size in torch.linspace(start,end,num):
+                        loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                        if loss < lowest_loss:
+                            lowest_loss = loss
+                            best_step_size = step_size
+
+                    return best_step_size
+
+        #### Using external solver via self.make_objective
+
+        Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
+
+        .. code-block:: python
+
+            class ScipyMinimizeScalar(LineSearch):
+                def __init__(self, method: str | None = None):
+                    defaults = dict(method=method)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def search(self, update, var):
+                    objective = self.make_objective(var=var)
+                    method = self.settings[var.params[0]]["method"]
+
+                    res = self.scopt.minimize_scalar(objective, method=method)
+                    return res.x
+
    """
    def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
        super().__init__(defaults)
@@ -165,17 +222,18 @@ class LineSearch(Module, ABC):
        return var
 
 
-class GridLineSearch(LineSearch):
-    """Mostly for testing, this is not practical"""
-    def __init__(self, start, end, num):
-        defaults = dict(start=start,end=end,num=num)
-        super().__init__(defaults)
 
-    @torch.no_grad
-    def search(self, update, var):
-        start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
+# class GridLineSearch(LineSearch):
+#     """Mostly for testing, this is not practical"""
+#     def __init__(self, start, end, num):
+#         defaults = dict(start=start,end=end,num=num)
+#         super().__init__(defaults)
+
+#     @torch.no_grad
+#     def search(self, update, var):
+#         start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
 
-        for lr in torch.linspace(start,end,num):
-            self.evaluate_step_size(lr.item(), var=var, backward=False)
+#         for lr in torch.linspace(start,end,num):
+#             self.evaluate_step_size(lr.item(), var=var, backward=False)
 
-        return self._best_step_size
+#         return self._best_step_size