torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@ from typing import Any, Literal
5
5
 
6
6
  import torch
7
7
 
8
- from ...core import Module, Vars
8
+ from ...core import Module, Var
9
9
 
10
10
  GradTarget = Literal['update', 'grad', 'closure']
11
11
  _Scalar = torch.Tensor | float
@@ -17,50 +17,50 @@ class GradApproximator(Module, ABC):
17
17
  Args:
18
18
  defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
19
19
  target (str, optional):
20
- whether to set `vars.grad`, `vars.update` or 'vars.closure`. Defaults to 'closure'.
20
+ whether to set `var.grad`, `var.update` or 'var.closure`. Defaults to 'closure'.
21
21
  """
22
22
  def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
23
23
  super().__init__(defaults)
24
24
  self._target: GradTarget = target
25
25
 
26
26
  @abstractmethod
27
- def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None, vars: Vars) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
27
+ def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None, var: Var) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
28
28
  """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""
29
29
 
30
- def pre_step(self, vars: Vars) -> Vars | None:
30
+ def pre_step(self, var: Var) -> Var | None:
31
31
  """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
32
32
  evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
33
- return vars
33
+ return var
34
34
 
35
35
  @torch.no_grad
36
- def step(self, vars):
37
- ret = self.pre_step(vars)
38
- if isinstance(ret, Vars): vars = ret
36
+ def step(self, var):
37
+ ret = self.pre_step(var)
38
+ if isinstance(ret, Var): var = ret
39
39
 
40
- if vars.closure is None: raise RuntimeError("Gradient approximation requires closure")
41
- params, closure, loss = vars.params, vars.closure, vars.loss
40
+ if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
41
+ params, closure, loss = var.params, var.closure, var.loss
42
42
 
43
43
  if self._target == 'closure':
44
44
 
45
45
  def approx_closure(backward=True):
46
46
  if backward:
47
47
  # set loss to None because closure might be evaluated at different points
48
- grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None, vars=vars)
48
+ grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None, var=var)
49
49
  for p, g in zip(params, grad): p.grad = g
50
50
  return l if l is not None else l_approx
51
51
  return closure(False)
52
52
 
53
- vars.closure = approx_closure
54
- return vars
53
+ var.closure = approx_closure
54
+ return var
55
55
 
56
- # if vars.grad is not None:
57
- # warnings.warn('Using grad approximator when `vars.grad` is already set.')
58
- grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss, vars=vars)
59
- if loss_approx is not None: vars.loss_approx = loss_approx
60
- if loss is not None: vars.loss = vars.loss_approx = loss
61
- if self._target == 'grad': vars.grad = list(grad)
62
- elif self._target == 'update': vars.update = list(grad)
56
+ # if var.grad is not None:
57
+ # warnings.warn('Using grad approximator when `var.grad` is already set.')
58
+ grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss, var=var)
59
+ if loss_approx is not None: var.loss_approx = loss_approx
60
+ if loss is not None: var.loss = var.loss_approx = loss
61
+ if self._target == 'grad': var.grad = list(grad)
62
+ elif self._target == 'update': var.update = list(grad)
63
63
  else: raise ValueError(self._target)
64
- return vars
64
+ return var
65
65
 
66
66
  _FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', 'central2', 'central4']
@@ -90,6 +90,19 @@ _RFD_FUNCS = {
90
90
 
91
91
 
92
92
  class RandomizedFDM(GradApproximator):
93
+ """Gradient approximation via randomized finite differences (directional derivatives along random perturbations).
94
+
95
+ Args:
96
+ h (float, optional): finite difference step size if jvp_method is set to `forward` or `central`. Defaults to 1e-3.
97
+ n_samples (int, optional): number of random gradient samples. Defaults to 1.
98
+ formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
99
+ distribution (Distributions, optional): distribution. Defaults to "rademacher".
100
+ beta (float, optional): If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
101
+ pre_generate (bool, optional):
102
+ whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
103
+ seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
104
+ target (GradTarget, optional): what to set on var. Defaults to "closure".
105
+ """
93
106
  PRE_MULTIPLY_BY_H = True
94
107
  def __init__(
95
108
  self,
@@ -99,8 +112,8 @@ class RandomizedFDM(GradApproximator):
99
112
  distribution: Distributions = "rademacher",
100
113
  beta: float = 0,
101
114
  pre_generate = True,
102
- target: GradTarget = "closure",
103
115
  seed: int | None | torch.Generator = None,
116
+ target: GradTarget = "closure",
104
117
  ):
105
118
  defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
106
119
  super().__init__(defaults, target=target)
@@ -118,16 +131,16 @@ class RandomizedFDM(GradApproximator):
118
131
  else: self.global_state['generator'] = None
119
132
  return self.global_state['generator']
120
133
 
121
- def pre_step(self, vars):
122
- h, beta = self.get_settings('h', 'beta', params=vars.params)
123
- settings = self.settings[vars.params[0]]
134
+ def pre_step(self, var):
135
+ h, beta = self.get_settings(var.params, 'h', 'beta')
136
+ settings = self.settings[var.params[0]]
124
137
  n_samples = settings['n_samples']
125
138
  distribution = settings['distribution']
126
139
  pre_generate = settings['pre_generate']
127
140
 
128
141
  if pre_generate:
129
- params = TensorList(vars.params)
130
- generator = self._get_generator(settings['seed'], vars.params)
142
+ params = TensorList(var.params)
143
+ generator = self._get_generator(settings['seed'], var.params)
131
144
  perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
132
145
 
133
146
  if self.PRE_MULTIPLY_BY_H:
@@ -152,11 +165,11 @@ class RandomizedFDM(GradApproximator):
152
165
  torch._foreach_lerp_(cur_flat, new_flat, betas)
153
166
 
154
167
  @torch.no_grad
155
- def approximate(self, closure, params, loss, vars):
168
+ def approximate(self, closure, params, loss, var):
156
169
  params = TensorList(params)
157
170
  loss_approx = None
158
171
 
159
- h = self.get_settings('h', params=vars.params, cls=NumberList)
172
+ h = NumberList(self.settings[p]['h'] for p in params)
160
173
  settings = self.settings[params[0]]
161
174
  n_samples = settings['n_samples']
162
175
  fd_fn = _RFD_FUNCS[settings['formula']]
@@ -220,29 +233,29 @@ class MeZO(GradApproximator):
220
233
  distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
221
234
  ).mul_(h)
222
235
 
223
- def pre_step(self, vars):
224
- h = self.get_settings('h', params=vars.params)
225
- settings = self.settings[vars.params[0]]
236
+ def pre_step(self, var):
237
+ h = NumberList(self.settings[p]['h'] for p in var.params)
238
+ settings = self.settings[var.params[0]]
226
239
  n_samples = settings['n_samples']
227
240
  distribution = settings['distribution']
228
241
 
229
- step = vars.current_step
242
+ step = var.current_step
230
243
 
231
244
  # create functions that generate a deterministic perturbation from seed based on current step
232
245
  prt_fns = []
233
246
  for i in range(n_samples):
234
247
 
235
- prt_fn = partial(self._seeded_perturbation, params=vars.params, distribution=distribution, seed=1_000_000*step + i, h=h)
248
+ prt_fn = partial(self._seeded_perturbation, params=var.params, distribution=distribution, seed=1_000_000*step + i, h=h)
236
249
  prt_fns.append(prt_fn)
237
250
 
238
251
  self.global_state['prt_fns'] = prt_fns
239
252
 
240
253
  @torch.no_grad
241
- def approximate(self, closure, params, loss, vars):
254
+ def approximate(self, closure, params, loss, var):
242
255
  params = TensorList(params)
243
256
  loss_approx = None
244
257
 
245
- h = self.get_settings('h', params=vars.params, cls=NumberList)
258
+ h = NumberList(self.settings[p]['h'] for p in params)
246
259
  settings = self.settings[params[0]]
247
260
  n_samples = settings['n_samples']
248
261
  fd_fn = _RFD_FUNCS[settings['formula']]
@@ -0,0 +1 @@
1
+ from .higher_order_newton import HigherOrderNewton
@@ -0,0 +1,256 @@
1
+ import itertools
2
+ import math
3
+ import warnings
4
+ from collections.abc import Callable
5
+ from contextlib import nullcontext
6
+ from functools import partial
7
+ from typing import Any, Literal
8
+
9
+ import numpy as np
10
+ import scipy.optimize
11
+ import torch
12
+
13
+ from ...core import Chainable, Module, apply_transform
14
+ from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
15
+ from ...utils.derivatives import (
16
+ hessian_list_to_mat,
17
+ jacobian_wrt,
18
+ )
19
+
20
+ _LETTERS = 'abcdefghijklmnopqrstuvwxyz'
21
+ def _poly_eval(s: np.ndarray, c, derivatives):
22
+ val = float(c)
23
+ for i,T in enumerate(derivatives, 1):
24
+ s1 = ''.join(_LETTERS[:i]) # abcd
25
+ s2 = ',...'.join(_LETTERS[:i]) # a,b,c,d
26
+ # this would make einsum('abcd,a,b,c,d', T, x, x, x, x)
27
+ val += np.einsum(f"...{s1},...{s2}", T, *(s for _ in range(i))) / math.factorial(i)
28
+ return val
29
+
30
+ def _proximal_poly_v(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
31
+ if x.ndim == 2: x = x.T # DE passes (ndim, batch_size)
32
+ s = x - x0
33
+ val = _poly_eval(s, c, derivatives)
34
+ penalty = 0
35
+ if prox != 0: penalty = (prox / 2) * (s**2).sum(-1) # proximal penalty
36
+ return val + penalty
37
+
38
+ def _proximal_poly_g(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
39
+ s = x - x0
40
+ g = derivatives[0].copy()
41
+ if len(derivatives) > 1:
42
+ for i, T in enumerate(derivatives[1:], 2):
43
+ s1 = ''.join(_LETTERS[:i]) # abcd
44
+ s2 = ','.join(_LETTERS[1:i]) # b,c,d
45
+ # this would make einsum('abcd,b,c,d->a', T, x, x, x)
46
+ g += np.einsum(f"{s1},{s2}->a", T, *(s for _ in range(i-1))) / math.factorial(i - 1)
47
+
48
+ g_prox = 0
49
+ if prox != 0: g_prox = prox * s
50
+ return g + g_prox
51
+
52
+ def _proximal_poly_H(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
53
+ s = x - x0
54
+ n = x.shape[0]
55
+ if len(derivatives) == 1:
56
+ H = np.zeros(n, n)
57
+ else:
58
+ H = derivatives[1].copy()
59
+ if len(derivatives) > 2:
60
+ for i, T in enumerate(derivatives[2:], 3):
61
+ s1 = ''.join(_LETTERS[:i]) # abcd
62
+ s2 = ','.join(_LETTERS[2:i]) # c,d
63
+ # this would make einsum('abcd,c,d->ab', T, x, x, x)
64
+ H += np.einsum(f"{s1},{s2}->ab", T, *(s for _ in range(i-2))) / math.factorial(i - 2)
65
+
66
+ H_prox = 0
67
+ if prox != 0: H_prox = np.eye(n) * prox
68
+ return H + H_prox
69
+
70
+ def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
71
+ derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
72
+ x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
73
+ bounds = None
74
+ if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
75
+
76
+ # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
77
+ if bounds is None:
78
+ if len(derivatives) == 1: method = 'bfgs'
79
+ else: method = 'trust-exact'
80
+ else:
81
+ if len(derivatives) == 1: method = 'l-bfgs-b'
82
+ else: method = 'trust-constr'
83
+
84
+ x_init = x0.copy()
85
+ v0 = _proximal_poly_v(x0, c, prox, x0, derivatives)
86
+ if de_iters is not None and de_iters != 0:
87
+ if de_iters == -1: de_iters = None # let scipy decide
88
+ res = scipy.optimize.differential_evolution(
89
+ _proximal_poly_v,
90
+ bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
91
+ args=(c, prox, x0.copy(), derivatives),
92
+ maxiter=de_iters,
93
+ vectorized=True,
94
+ )
95
+ if res.fun < v0: x_init = res.x
96
+
97
+ res = scipy.optimize.minimize(
98
+ _proximal_poly_v,
99
+ x_init,
100
+ method=method,
101
+ args=(c, prox, x0.copy(), derivatives),
102
+ jac=_proximal_poly_g,
103
+ hess=_proximal_poly_H,
104
+ bounds=bounds
105
+ )
106
+
107
+ return torch.from_numpy(res.x).to(x), res.fun
108
+
109
+
110
+
111
+ class HigherOrderNewton(Module):
112
+ """
113
+ A basic arbitrary order newton's method with optional trust region and proximal penalty.
114
+ It is recommended to enable at least one of trust region or proximal penalty.
115
+
116
+ This constructs an nth order taylor approximation via autograd and minimizes it with
117
+ scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
118
+
119
+ This uses n^order memory, where n is number of decision variables, and I am not aware
120
+ of any problems where this is more efficient than newton's method. It can minimize
121
+ rosenbrock in a single step, but that step probably takes more time than newton.
122
+ And there are way more efficient tensor methods out there but they tend to be
123
+ significantly more complex.
124
+
125
+ Args:
126
+
127
+ order (int, optional):
128
+ Order of the method, number of taylor series terms (orders of derivatives) used to approximate the function. Defaults to 4.
129
+ trust_method (str | None, optional):
130
+ Method used for trust region.
131
+ - "bounds" - the model is minimized within bounds defined by trust region.
132
+ - "proximal" - the model is minimized with penalty for going too far from current point.
133
+ - "none" - disables trust region.
134
+
135
+ Defaults to 'bounds'.
136
+ increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
137
+ decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
138
+ trust_init (float | None, optional):
139
+ initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on :code:`"proximal"`. Defaults to None.
140
+ trust_tol (float, optional):
141
+ Maximum ratio of expected loss reduction to actual reduction for trust region increase.
142
+ Should be 1 or higher. Defaults to 2.
143
+ de_iters (int | None, optional):
144
+ If this is specified, the model is minimized via differential evolution first to possibly escape local minima,
145
+ then it is passed to scipy.optimize.minimize. Defaults to None.
146
+ vectorize (bool, optional): whether to enable vectorized jacobians (usually faster). Defaults to True.
147
+ """
148
+ def __init__(
149
+ self,
150
+ order: int = 4,
151
+ trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
152
+ increase: float = 1.5,
153
+ decrease: float = 0.75,
154
+ trust_init: float | None = None,
155
+ trust_tol: float = 2,
156
+ de_iters: int | None = None,
157
+ vectorize: bool = True,
158
+ ):
159
+ if trust_init is None:
160
+ if trust_method == 'bounds': trust_init = 1
161
+ else: trust_init = 0.1
162
+
163
+ defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
164
+ super().__init__(defaults)
165
+
166
+ @torch.no_grad
167
+ def step(self, var):
168
+ params = TensorList(var.params)
169
+ closure = var.closure
170
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
171
+
172
+ settings = self.settings[params[0]]
173
+ order = settings['order']
174
+ increase = settings['increase']
175
+ decrease = settings['decrease']
176
+ trust_tol = settings['trust_tol']
177
+ trust_init = settings['trust_init']
178
+ trust_method = settings['trust_method']
179
+ de_iters = settings['de_iters']
180
+ vectorize = settings['vectorize']
181
+
182
+ trust_value = self.global_state.get('trust_value', trust_init)
183
+
184
+
185
+ # ------------------------ calculate grad and hessian ------------------------ #
186
+ with torch.enable_grad():
187
+ loss = var.loss = var.loss_approx = closure(False)
188
+
189
+ g_list = torch.autograd.grad(loss, params, create_graph=True)
190
+ var.grad = list(g_list)
191
+
192
+ g = torch.cat([t.ravel() for t in g_list])
193
+ n = g.numel()
194
+ derivatives = [g]
195
+ T = g # current derivatives tensor
196
+
197
+ # get all derivative up to order
198
+ for o in range(2, order + 1):
199
+ is_last = o == order
200
+ T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
201
+ with torch.no_grad() if is_last else nullcontext():
202
+ # the shape is (ndim, ) * order
203
+ T = hessian_list_to_mat(T_list).view(n, n, *T.shape[1:])
204
+ derivatives.append(T)
205
+
206
+ x0 = torch.cat([p.ravel() for p in params])
207
+
208
+ if trust_method is None: trust_method = 'none'
209
+ else: trust_method = trust_method.lower()
210
+
211
+ if trust_method == 'none':
212
+ trust_region = None
213
+ prox = 0
214
+
215
+ elif trust_method == 'bounds':
216
+ trust_region = trust_value
217
+ prox = 0
218
+
219
+ elif trust_method == 'proximal':
220
+ trust_region = None
221
+ prox = 1 / trust_value
222
+
223
+ else:
224
+ raise ValueError(trust_method)
225
+
226
+ x_star, expected_loss = _poly_minimize(
227
+ trust_region=trust_region,
228
+ prox=prox,
229
+ de_iters=de_iters,
230
+ c=loss.item(),
231
+ x=x0,
232
+ derivatives=derivatives,
233
+ )
234
+
235
+ # trust region
236
+ if trust_method != 'none':
237
+ expected_reduction = loss - expected_loss
238
+
239
+ vec_to_tensors_(x_star, params)
240
+ loss_star = closure(False)
241
+ vec_to_tensors_(x0, params)
242
+ reduction = loss - loss_star
243
+
244
+ # failed step
245
+ if reduction <= 0:
246
+ x_star = x0
247
+ self.global_state['trust_value'] = trust_value * decrease
248
+
249
+ # very good step
250
+ elif expected_reduction / reduction <= trust_tol:
251
+ self.global_state['trust_value'] = trust_value * increase
252
+
253
+ difference = vec_to_tensors(x0 - x_star, params)
254
+ var.update = list(difference)
255
+ return var
256
+
@@ -14,7 +14,6 @@ def backtracking_line_search(
14
14
  beta: float = 0.5,
15
15
  c: float = 1e-4,
16
16
  maxiter: int = 10,
17
- a_min: float | None = None,
18
17
  try_negative: bool = False,
19
18
  ) -> float | None:
20
19
  """
@@ -26,7 +25,6 @@ def backtracking_line_search(
26
25
  beta: The factor by which to decrease alpha in each iteration
27
26
  c: The constant for the Armijo sufficient decrease condition
28
27
  max_iter: Maximum number of backtracking iterations (default: 10).
29
- min_alpha: Minimum allowable step size to prevent near-zero values (default: 1e-16).
30
28
 
31
29
  Returns:
32
30
  step size
@@ -45,10 +43,6 @@ def backtracking_line_search(
45
43
  # decrease alpha
46
44
  a *= beta
47
45
 
48
- # alpha too small
49
- if a_min is not None and a < a_min:
50
- return a_min
51
-
52
46
  # fail
53
47
  if try_negative:
54
48
  def inv_objective(alpha): return f(-alpha)
@@ -59,7 +53,6 @@ def backtracking_line_search(
59
53
  beta=beta,
60
54
  c=c,
61
55
  maxiter=maxiter,
62
- a_min=a_min,
63
56
  try_negative=False,
64
57
  )
65
58
  if v is not None: return -v
@@ -67,17 +60,28 @@ def backtracking_line_search(
67
60
  return None
68
61
 
69
62
  class Backtracking(LineSearch):
63
+ """Backtracking line search satisfying the Armijo condition.
64
+
65
+ Args:
66
+ init (float, optional): initial step size. Defaults to 1.0.
67
+ beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
68
+ c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
69
+ maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
70
+ adaptive (bool, optional):
71
+ when enabled, if line search failed, initial step size is reduced.
72
+ Otherwise it is reset to initial value. Defaults to True.
73
+ try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
74
+ """
70
75
  def __init__(
71
76
  self,
72
77
  init: float = 1.0,
73
78
  beta: float = 0.5,
74
79
  c: float = 1e-4,
75
80
  maxiter: int = 10,
76
- min_alpha: float | None = None,
77
81
  adaptive=True,
78
82
  try_negative: bool = False,
79
83
  ):
80
- defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,min_alpha=min_alpha,adaptive=adaptive, try_negative=try_negative)
84
+ defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive, try_negative=try_negative)
81
85
  super().__init__(defaults=defaults)
82
86
  self.global_state['beta_scale'] = 1.0
83
87
 
@@ -86,20 +90,20 @@ class Backtracking(LineSearch):
86
90
  self.global_state['beta_scale'] = 1.0
87
91
 
88
92
  @torch.no_grad
89
- def search(self, update, vars):
90
- init, beta, c, maxiter, min_alpha, adaptive, try_negative = itemgetter(
91
- 'init', 'beta', 'c', 'maxiter', 'min_alpha', 'adaptive', 'try_negative')(self.settings[vars.params[0]])
93
+ def search(self, update, var):
94
+ init, beta, c, maxiter, adaptive, try_negative = itemgetter(
95
+ 'init', 'beta', 'c', 'maxiter', 'adaptive', 'try_negative')(self.settings[var.params[0]])
92
96
 
93
- objective = self.make_objective(vars=vars)
97
+ objective = self.make_objective(var=var)
94
98
 
95
99
  # # directional derivative
96
- d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
100
+ d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
97
101
 
98
102
  # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
99
103
  if adaptive: beta = beta * self.global_state['beta_scale']
100
104
 
101
105
  step_size = backtracking_line_search(objective, d, init=init,beta=beta,
102
- c=c,maxiter=maxiter,a_min=min_alpha, try_negative=try_negative)
106
+ c=c,maxiter=maxiter, try_negative=try_negative)
103
107
 
104
108
  # found an alpha that reduces loss
105
109
  if step_size is not None:
@@ -114,19 +118,34 @@ def _lerp(start,end,weight):
114
118
  return start + weight * (end - start)
115
119
 
116
120
  class AdaptiveBacktracking(LineSearch):
121
+ """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
122
+ such that optimal step size in the procedure would be found on the second line search iteration.
123
+
124
+ Args:
125
+ init (float, optional): step size for the first step. Defaults to 1.0.
126
+ beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
127
+ c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
128
+ maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
129
+ target_iters (int, optional):
130
+ target number of iterations that would be performed until optimal step size is found. Defaults to 1.
131
+ nplus (float, optional):
132
+ Multiplier to initial step size if it was found to be the optimal step size. Defaults to 2.0.
133
+ scale_beta (float, optional):
134
+ Momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
135
+ try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
136
+ """
117
137
  def __init__(
118
138
  self,
119
139
  init: float = 1.0,
120
140
  beta: float = 0.5,
121
141
  c: float = 1e-4,
122
142
  maxiter: int = 20,
123
- min_alpha: float | None = None,
124
143
  target_iters = 1,
125
144
  nplus = 2.0,
126
145
  scale_beta = 0.0,
127
146
  try_negative: bool = False,
128
147
  ):
129
- defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,min_alpha=min_alpha,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
148
+ defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
130
149
  super().__init__(defaults=defaults)
131
150
 
132
151
  self.global_state['beta_scale'] = 1.0
@@ -138,15 +157,15 @@ class AdaptiveBacktracking(LineSearch):
138
157
  self.global_state['initial_scale'] = 1.0
139
158
 
140
159
  @torch.no_grad
141
- def search(self, update, vars):
142
- init, beta, c, maxiter, min_alpha, target_iters, nplus, scale_beta, try_negative=itemgetter(
143
- 'init','beta','c','maxiter','min_alpha','target_iters','nplus','scale_beta', 'try_negative')(self.settings[vars.params[0]])
160
+ def search(self, update, var):
161
+ init, beta, c, maxiter, target_iters, nplus, scale_beta, try_negative=itemgetter(
162
+ 'init','beta','c','maxiter','target_iters','nplus','scale_beta', 'try_negative')(self.settings[var.params[0]])
144
163
 
145
- objective = self.make_objective(vars=vars)
164
+ objective = self.make_objective(var=var)
146
165
 
147
166
  # directional derivative (0 if c = 0 because it is not needed)
148
167
  if c == 0: d = 0
149
- else: d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), update))
168
+ else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
150
169
 
151
170
  # scale beta
152
171
  beta = beta * self.global_state['beta_scale']
@@ -155,7 +174,7 @@ class AdaptiveBacktracking(LineSearch):
155
174
  init = init * self.global_state['initial_scale']
156
175
 
157
176
  step_size = backtracking_line_search(objective, d, init=init, beta=beta,
158
- c=c,maxiter=maxiter,a_min=min_alpha, try_negative=try_negative)
177
+ c=c,maxiter=maxiter, try_negative=try_negative)
159
178
 
160
179
  # found an alpha that reduces loss
161
180
  if step_size is not None: