torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/adaptive.py
@@ -1,58 +1,73 @@
 import math
+from bisect import insort
+from collections import deque
 from collections.abc import Callable
 from operator import itemgetter
 
+import numpy as np
 import torch
 
-from .line_search import LineSearchBase
-
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
 
 
 def adaptive_tracking(
     f,
-    x_0,
+    a_init,
     maxiter: int,
     nplus: float = 2,
     nminus: float = 0.5,
+    f_0 = None,
 ):
-    f_0 = f(0)
+    niter = 0
+    if f_0 is None: f_0 = f(0)
 
-    t = x_0
-    f_t = f(t)
+    a = a_init
+    f_a = f(a)
 
     # backtrack
-    if f_t > f_0:
-        while f_t > f_0:
+    a_prev = a
+    f_prev = math.inf
+    if (f_a > f_0) or (not math.isfinite(f_a)):
+        while (f_a < f_prev) or not math.isfinite(f_a):
+            a_prev, f_prev = a, f_a
             maxiter -= 1
-            if maxiter < 0: return 0, f_0
-            t = t*nminus
-            f_t = f(t)
-        return t, f_t
+            if maxiter < 0: break
+
+            a = a*nminus
+            f_a = f(a)
+            niter += 1
+
+        if f_prev < f_0: return a_prev, f_prev, niter
+        return 0, f_0, niter
 
     # forwardtrack
-    f_prev = f_t
-    t *= nplus
-    f_t = f(t)
-    if f_prev < f_t: return t / nplus, f_prev
-    while f_prev >= f_t:
+    a_prev = a
+    f_prev = math.inf
+    while (f_a <= f_prev) and math.isfinite(f_a):
+        a_prev, f_prev = a, f_a
         maxiter -= 1
-        if maxiter < 0: return t, f_t
-        f_prev = f_t
-        t *= nplus
-        f_t = f(t)
-    return t / nplus, f_prev
+        if maxiter < 0: break
+
+        a *= nplus
+        f_a = f(a)
+        niter += 1
+
+    if f_prev < f_0: return a_prev, f_prev, niter
+    return 0, f_0, niter
+
 
-class AdaptiveLineSearch(LineSearchBase):
-    """Adaptive line search, similar to backtracking but also has forward tracking mode.
-    Currently doesn't check for weak curvature condition.
+class AdaptiveTracking(LineSearchBase):
+    """A line search that evaluates the previous step size; if the value increased, it backtracks until the value stops decreasing,
+    otherwise it forward-tracks until the value stops decreasing.
 
     Args:
         init (float, optional): initial step size. Defaults to 1.0.
-        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        nplus (float, optional): multiplier to the step size if the initial step size is optimal. Defaults to 2.
+        nminus (float, optional): multiplier to the step size if the initial step size is too big. Defaults to 0.5.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         adaptive (bool, optional):
-            when enabled, if line search failed, beta size is reduced.
-            Otherwise it is reset to initial value. Defaults to True.
+            when enabled, if the line search failed, the step size will continue decreasing on the next step.
+            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
     """
     def __init__(
         self,
@@ -62,38 +77,48 @@ class AdaptiveLineSearch(LineSearchBase):
         maxiter: int = 10,
         adaptive=True,
     ):
-        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive,)
+        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive)
         super().__init__(defaults=defaults)
-        self.global_state['beta_scale'] = 1.0
 
     def reset(self):
         super().reset()
-        self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
     def search(self, update, var):
         init, nplus, nminus, maxiter, adaptive = itemgetter(
-            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.settings[var.params[0]])
+            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.defaults)
 
         objective = self.make_objective(var=var)
 
-        # # directional derivative
-        # d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+        # scale a_prev
+        a_prev = self.global_state.get('a_prev', init)
+        if adaptive: a_prev = a_prev * self.global_state.get('init_scale', 1)
 
-        # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
-        beta_scale = self.global_state.get('beta_scale', 1)
-        x_prev = self.global_state.get('prev_x', 1)
+        a_init = a_prev
+        if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
+            a_init = torch.finfo(var.params[0].dtype).max / 2
 
-        if adaptive: nminus = nminus * beta_scale
-
-
-        step_size, f = adaptive_tracking(objective, x_prev, maxiter, nplus=nplus, nminus=nminus)
+        step_size, f, niter = adaptive_tracking(
+            objective,
+            a_init=a_init,
+            maxiter=maxiter,
+            nplus=nplus,
+            nminus=nminus,
+        )
 
         # found an alpha that reduces loss
        if step_size != 0:
-            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            assert (var.loss is None) or (math.isfinite(f) and f < var.loss)
+            self.global_state['init_scale'] = 1
+
+            # if niter == 1, forward tracking failed to decrease the function value compared to f_prev
+            if niter == 1 and step_size >= a_init: step_size *= nminus
+
+            self.global_state['a_prev'] = step_size
             return step_size
 
         # on fail reduce beta scale value
-        self.global_state['beta_scale'] /= 1.5
+        self.global_state['init_scale'] = self.global_state.get('init_scale', 1) * nminus**maxiter
+        self.global_state['a_prev'] = init
         return 0
+
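The rewritten `adaptive_tracking` helper now takes the previous step size (`a_init`), skips non-finite objective values, and returns a `(step_size, value, evaluations)` triple instead of a pair. A minimal sketch of how it behaves on a one-dimensional quadratic, assuming the helper stays importable from the module path shown in the file list (the import path and the printed numbers are illustrative, not documented API):

```python
# Hedged sketch: exercises the new three-value return of adaptive_tracking.
from torchzero.modules.line_search.adaptive import adaptive_tracking

def phi(a):
    # 1-D restriction f(x + a*d) of a quadratic whose minimum lies near a = 0.3
    return (a - 0.3) ** 2

# phi(1.0) > phi(0), so the helper backtracks 1.0 -> 0.5 -> 0.25 -> 0.125,
# notices the value rose again at 0.125, and returns the best step it saw,
# its value, and the number of loop evaluations it spent.
a, f_a, niter = adaptive_tracking(phi, a_init=1.0, maxiter=10)
print(a, f_a, niter)  # roughly (0.25, 0.0025, 3)
```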
torchzero/modules/line_search/backtracking.py
@@ -4,7 +4,7 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import LineSearchBase
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
 
 
 def backtracking_line_search(
@@ -14,7 +14,7 @@ def backtracking_line_search(
     beta: float = 0.5,
     c: float = 1e-4,
     maxiter: int = 10,
-    try_negative: bool = False,
+    condition: TerminationCondition = 'armijo',
 ) -> float | None:
     """
 
@@ -31,16 +31,20 @@
     """
 
     a = init
-    f_x = f(0)
+    f_0 = f(0)
     f_prev = None
 
     for iteration in range(maxiter):
         f_a = f(a)
+        if not math.isfinite(f_a):
+            a *= beta
+            continue
 
-        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_x): return a / beta
+        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_0):
+            return a / beta  # new value is larger than previous value
         f_prev = f_a
 
-        if f_a < f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
+        if termination_condition(condition, f_0=f_0, g_0=g_0, f_a=f_a, g_a=None, a=a, c=c):
             # found an acceptable alpha
             return a
 
@@ -48,53 +52,45 @@
         a *= beta
 
     # fail
-    if try_negative:
-        def inv_objective(alpha): return f(-alpha)
-
-        v = backtracking_line_search(
-            inv_objective,
-            g_0=-g_0,
-            beta=beta,
-            c=c,
-            maxiter=maxiter,
-            try_negative=False,
-        )
-        if v is not None: return -v
-
     return None
 
 class Backtracking(LineSearchBase):
-    """Backtracking line search satisfying the Armijo condition.
+    """Backtracking line search.
 
     Args:
         init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
-        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        c (float, optional): constant for the sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition; only conditions that do not use the gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can technically be specified but it doesn't make sense because there is no zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         adaptive (bool, optional):
-            when enabled, if line search failed, beta is reduced.
-            Otherwise it is reset to initial value. Defaults to True.
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            when enabled, if the line search failed, the step size will continue decreasing on the next step.
+            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
 
     Examples:
-        Gradient descent with backtracking line search:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Backtracking()
-            )
-
-        LBFGS with backtracking line search:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LBFGS(),
-                tz.m.Backtracking()
-            )
+        Gradient descent with backtracking line search:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Backtracking()
+        )
+        ```
+
+        L-BFGS with backtracking line search:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LBFGS(),
+            tz.m.Backtracking()
+        )
+        ```
 
     """
     def __init__(
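For reference, the two gradient-free conditions that the new `condition` argument accepts (per the docstring above) reduce to checks like the following. This is only a sketch of the logic, not the actual `termination_condition` helper added to `line_search.py`, whose signature is visible only in the call site above:

```python
import math

def satisfies(condition: str, f_0: float, g_0: float, f_a: float, a: float, c: float) -> bool:
    """Sketch of the gradient-free acceptance tests; the real helper may differ."""
    if condition == "armijo":
        # sufficient decrease: f(x + a*d) <= f(x) + c * a * f'(x; d), with f'(x; d) <= 0
        return f_a <= f_0 + c * a * min(g_0, 0.0)
    if condition == "decrease":
        # any finite decrease of the objective value is accepted
        return math.isfinite(f_a) and f_a < f_0
    raise ValueError(f"unsupported condition: {condition}")
```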
@@ -102,41 +98,47 @@ class Backtracking(LineSearchBase):
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
         maxiter: int = 10,
         adaptive=True,
-        try_negative: bool = False,
     ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive, try_negative=try_negative)
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,adaptive=adaptive)
         super().__init__(defaults=defaults)
-        self.global_state['beta_scale'] = 1.0
 
     def reset(self):
         super().reset()
-        self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
     def search(self, update, var):
-        init, beta, c, maxiter, adaptive, try_negative = itemgetter(
-            'init', 'beta', 'c', 'maxiter', 'adaptive', 'try_negative')(self.settings[var.params[0]])
+        init, beta, c, condition, maxiter, adaptive = itemgetter(
+            'init', 'beta', 'c', 'condition', 'maxiter', 'adaptive')(self.defaults)
 
         objective = self.make_objective(var=var)
 
         # # directional derivative
-        d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+        if c == 0: d = 0
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
 
-        # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
-        if adaptive: beta = beta * self.global_state['beta_scale']
+        # scale init
+        init_scale = self.global_state.get('init_scale', 1)
+        if adaptive: init = init * init_scale
 
-        step_size = backtracking_line_search(objective, d, init=init,beta=beta,
-                                              c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
-            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            #self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            self.global_state['init_scale'] = 1
             return step_size
 
-        # on fail reduce beta scale value
-        self.global_state['beta_scale'] /= 1.5
+        # on fail set init_scale to continue decreasing the step size
+        # or set to large step size when it becomes too small
+        if adaptive:
+            finfo = torch.finfo(var.params[0].dtype)
+            if init_scale <= finfo.tiny * 2:
+                self.global_state["init_scale"] = finfo.max / 2
+            else:
+                self.global_state['init_scale'] = init_scale * beta**maxiter
         return 0
 
 def _lerp(start,end,weight):
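The new `adaptive` failure branch above replaces the old multiplicative `beta_scale` with an `init_scale` that shrinks by `beta**maxiter` after a failed search, so the next search resumes roughly where the failed one stopped, and resets to a very large value once it underflows. A back-of-the-envelope check of that update rule with the default settings (the numbers are illustrative only):

```python
import torch

beta, maxiter = 0.5, 10
init_scale = 1.0

# one failed line search: the next attempt starts 2**-10 (~1e-3) times smaller
init_scale *= beta ** maxiter
print(init_scale)                  # 0.0009765625

# after enough failures the scale underflows ...
finfo = torch.finfo(torch.float32)
init_scale = finfo.tiny            # pretend the underflow threshold was reached
if init_scale <= finfo.tiny * 2:
    init_scale = finfo.max / 2     # ... and is reset to a huge step size
print(init_scale)                  # ~1.7e38
```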
@@ -147,30 +149,37 @@ class AdaptiveBacktracking(LineSearchBase):
     such that optimal step size in the procedure would be found on the second line search iteration.
 
     Args:
-        init (float, optional): step size for the first step. Defaults to 1.0.
+        init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
-        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        c (float, optional): constant for the sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition; only conditions that do not use the gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can technically be specified but it doesn't make sense because there is no zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         target_iters (int, optional):
-            target number of iterations that would be performed until optimal step size is found. Defaults to 1.
+            sets the next step size such that this number of iterations is expected
+            to be performed until the optimal step size is found. Defaults to 1.
         nplus (float, optional):
-            Multiplier to initial step size if it was found to be the optimal step size. Defaults to 2.0.
+            if the initial step size is optimal, it is multiplied by this value. Defaults to 2.0.
         scale_beta (float, optional):
-            Momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
     """
     def __init__(
         self,
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
         maxiter: int = 20,
         target_iters = 1,
         nplus = 2.0,
         scale_beta = 0.0,
-        try_negative: bool = False,
     ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
         super().__init__(defaults=defaults)
 
         self.global_state['beta_scale'] = 1.0
@@ -183,8 +192,8 @@ class AdaptiveBacktracking(LineSearchBase):
 
     @torch.no_grad
     def search(self, update, var):
-        init, beta, c, maxiter, target_iters, nplus, scale_beta, try_negative=itemgetter(
-            'init','beta','c','maxiter','target_iters','nplus','scale_beta', 'try_negative')(self.settings[var.params[0]])
+        init, beta, c, condition, maxiter, target_iters, nplus, scale_beta = itemgetter(
+            'init','beta','c','condition','maxiter','target_iters','nplus','scale_beta')(self.defaults)
 
         objective = self.make_objective(var=var)
 
@@ -198,8 +207,7 @@ class AdaptiveBacktracking(LineSearchBase):
         # scale step size so that decrease is expected at target_iters
         init = init * self.global_state['initial_scale']
 
-        step_size = backtracking_line_search(objective, d, init=init, beta=beta,
-                                              c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
@@ -208,7 +216,12 @@ class AdaptiveBacktracking(LineSearchBase):
             # initial step size satisfied conditions, increase initial_scale by nplus
             if step_size == init and target_iters > 0:
                 self.global_state['initial_scale'] *= nplus ** target_iters
-                self.global_state['initial_scale'] = min(self.global_state['initial_scale'], 1e32) # avoid overflow error
+
+                # clip by maximum possible value to avoid overflow exception
+                self.global_state['initial_scale'] = min(
+                    self.global_state['initial_scale'],
+                    torch.finfo(var.params[0].dtype).max / 2,
+                )
 
             else:
                 # otherwise make initial_scale such that target_iters iterations will satisfy armijo
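Mirroring the usage examples in the `Backtracking` docstring, the renamed `AdaptiveTracking` class would presumably be combined with other modules the same way. Whether it is re-exported as `tz.m.AdaptiveTracking` is an assumption based on the class rename, so treat this as a sketch rather than documented usage:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# L-BFGS direction followed by the forward-/backtracking line search from adaptive.py.
# tz.m.AdaptiveTracking is assumed to follow the tz.m.<ClassName> convention used in
# the docstring examples; adjust the name if the package exports it differently.
opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.AdaptiveTracking(init=1.0, nplus=2, nminus=0.5, maxiter=10),
)
```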