torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/backtracking.py

@@ -4,7 +4,7 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import LineSearchBase
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
 
 
 def backtracking_line_search(
@@ -14,7 +14,7 @@ def backtracking_line_search(
     beta: float = 0.5,
     c: float = 1e-4,
     maxiter: int = 10,
-    try_negative: bool = False,
+    condition: TerminationCondition = 'armijo',
 ) -> float | None:
     """
 
@@ -31,16 +31,20 @@ def backtracking_line_search(
     """
 
     a = init
-    f_x = f(0)
+    f_0 = f(0)
     f_prev = None
 
     for iteration in range(maxiter):
         f_a = f(a)
+        if not math.isfinite(f_a):
+            a *= beta
+            continue
 
-        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_x): return a / beta
+        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_0):
+            return a / beta # new value is larger than previous value
         f_prev = f_a
 
-        if f_a < f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
+        if termination_condition(condition, f_0=f_0, g_0=g_0, f_a=f_a, g_a=None, a=a, c=c):
             # found an acceptable alpha
             return a
 
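For reference, the refactored loop above reduces to the following standalone sketch (hypothetical names, a plain Python callable instead of torchzero's objective wrapper, and the default Armijo check inlined in place of `termination_condition`): shrink the trial step by `beta`, skip non-finite evaluations, and bail out early once the objective starts growing again.

```python
import math

def backtracking_sketch(f, g_0, init=1.0, beta=0.5, c=1e-4, maxiter=10):
    """Minimal sketch of the loop above. f(a) evaluates the objective at step size a,
    g_0 is the directional derivative at a=0 (negative along a descent direction)."""
    a = init
    f_0 = f(0.0)
    f_prev = None
    for _ in range(maxiter):
        f_a = f(a)
        if not math.isfinite(f_a):            # overflow/NaN: keep shrinking
            a *= beta
            continue
        if f_prev is not None and f_a > f_prev and f_prev < f_0:
            return a / beta                   # objective started growing again, previous step was better
        f_prev = f_a
        if f_a < f_0 + c * a * min(g_0, 0.0): # Armijo sufficient decrease (the "armijo" condition)
            return a
        a *= beta
    return None                               # no acceptable step found

# toy check on f(x) = x^2 starting from x = 2 along the steepest descent direction d = -4:
f = lambda a: (2.0 - 4.0 * a) ** 2
print(backtracking_sketch(f, g_0=-16.0))      # prints 0.5
```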
@@ -48,53 +52,45 @@ def backtracking_line_search(
         a *= beta
 
     # fail
-    if try_negative:
-        def inv_objective(alpha): return f(-alpha)
-
-        v = backtracking_line_search(
-            inv_objective,
-            g_0=-g_0,
-            beta=beta,
-            c=c,
-            maxiter=maxiter,
-            try_negative=False,
-        )
-        if v is not None: return -v
-
     return None
 
 class Backtracking(LineSearchBase):
-    """Backtracking line search satisfying the Armijo condition.
+    """Backtracking line search.
 
     Args:
         init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
-        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can techincally be specified but it doesn't make sense because there is not zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         adaptive (bool, optional):
-            when enabled, if line search failed, beta is reduced.
-            Otherwise it is reset to initial value. Defaults to True.
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            when enabled, if line search failed, step size will continue decreasing on the next step.
+            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
 
     Examples:
-        Gradient descent with backtracking line search:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Backtracking()
-            )
-
-        LBFGS with backtracking line search:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LBFGS(),
-                tz.m.Backtracking()
-            )
+        Gradient descent with backtracking line search:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Backtracking()
+        )
+        ```
+
+        L-BFGS with backtracking line search:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LBFGS(),
+            tz.m.Backtracking()
+        )
+        ```
 
     """
     def __init__(
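Building on the updated signature, the new `condition` option drops straight into the usage pattern from the docstring above; a sketch, assuming the usual `import torchzero as tz` and that `Backtracking` stays exported as `tz.m.Backtracking`:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# same pattern as the docstring examples, but accepting any step that merely
# decreases the loss instead of requiring the Armijo sufficient decrease
opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.Backtracking(beta=0.75, condition="decrease"),
)
```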
@@ -102,41 +98,47 @@ class Backtracking(LineSearchBase):
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
         maxiter: int = 10,
         adaptive=True,
-        try_negative: bool = False,
     ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive, try_negative=try_negative)
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,adaptive=adaptive)
         super().__init__(defaults=defaults)
-        self.global_state['beta_scale'] = 1.0
 
     def reset(self):
         super().reset()
-        self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
     def search(self, update, var):
-        init, beta, c, maxiter, adaptive, try_negative = itemgetter(
-            'init', 'beta', 'c', 'maxiter', 'adaptive', 'try_negative')(self.settings[var.params[0]])
+        init, beta, c, condition, maxiter, adaptive = itemgetter(
+            'init', 'beta', 'c', 'condition', 'maxiter', 'adaptive')(self.defaults)
 
         objective = self.make_objective(var=var)
 
         # # directional derivative
-        d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+        if c == 0: d = 0
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
 
-        # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
-        if adaptive: beta = beta * self.global_state['beta_scale']
+        # scale init
+        init_scale = self.global_state.get('init_scale', 1)
+        if adaptive: init = init * init_scale
 
-        step_size = backtracking_line_search(objective, d, init=init,beta=beta,
-                                             c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta,c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
-            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            #self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            self.global_state['init_scale'] = 1
             return step_size
 
-        # on fail reduce beta scale value
-        self.global_state['beta_scale'] /= 1.5
+        # on fail set init_scale to continue decreasing the step size
+        # or set to large step size when it becomes too small
+        if adaptive:
+            finfo = torch.finfo(var.params[0].dtype)
+            if init_scale <= finfo.tiny * 2:
+                self.global_state["init_scale"] = finfo.max / 2
+            else:
+                self.global_state['init_scale'] = init_scale * beta**maxiter
         return 0
 
 def _lerp(start,end,weight):
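To make the adaptive bookkeeping above concrete, here is a standalone trace (plain floats, hypothetical helper name) of how `init_scale` evolves: it resets to 1 on success, otherwise shrinks by `beta**maxiter` so the next search resumes roughly where the failed one stopped, and wraps around to a very large value once it underflows.

```python
import torch

def next_init_scale(init_scale, success, beta=0.5, maxiter=10, dtype=torch.float32):
    """One step of the init_scale bookkeeping sketched from the code above."""
    finfo = torch.finfo(dtype)
    if success:
        return 1.0                          # found a step: start from `init` again next time
    if init_scale <= finfo.tiny * 2:
        return finfo.max / 2                # scale underflowed: restart from a huge trial step
    return init_scale * beta ** maxiter     # keep shrinking where the failed search left off

scale = 1.0
for success in (False, False, True):
    scale = next_init_scale(scale, success)
    print(success, scale)                   # ~9.8e-4, ~9.5e-7, then back to 1.0
```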
@@ -147,30 +149,37 @@ class AdaptiveBacktracking(LineSearchBase):
     such that optimal step size in the procedure would be found on the second line search iteration.
 
     Args:
-        init (float, optional): step size for the first step. Defaults to 1.0.
+        init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
-        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can techincally be specified but it doesn't make sense because there is not zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         target_iters (int, optional):
-            target number of iterations that would be performed until optimal step size is found. Defaults to 1.
+            sets next step size such that this number of iterations are expected
+            to be performed until optimal step size is found. Defaults to 1.
         nplus (float, optional):
-            Multiplier to initial step size if it was found to be the optimal step size. Defaults to 2.0.
+            if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.
         scale_beta (float, optional):
-            Momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
     """
     def __init__(
         self,
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
         maxiter: int = 20,
         target_iters = 1,
         nplus = 2.0,
         scale_beta = 0.0,
-        try_negative: bool = False,
     ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
         super().__init__(defaults=defaults)
 
         self.global_state['beta_scale'] = 1.0
@@ -183,8 +192,8 @@ class AdaptiveBacktracking(LineSearchBase):
 
     @torch.no_grad
     def search(self, update, var):
-        init, beta, c, maxiter, target_iters, nplus, scale_beta, try_negative=itemgetter(
-            'init','beta','c','maxiter','target_iters','nplus','scale_beta', 'try_negative')(self.settings[var.params[0]])
+        init, beta, c,condition, maxiter, target_iters, nplus, scale_beta=itemgetter(
+            'init','beta','c','condition', 'maxiter','target_iters','nplus','scale_beta')(self.defaults)
 
         objective = self.make_objective(var=var)
 
@@ -198,8 +207,7 @@ class AdaptiveBacktracking(LineSearchBase):
         # scale step size so that decrease is expected at target_iters
         init = init * self.global_state['initial_scale']
 
-        step_size = backtracking_line_search(objective, d, init=init, beta=beta,
-                                             c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
@@ -208,7 +216,12 @@ class AdaptiveBacktracking(LineSearchBase):
 
             # initial step size satisfied conditions, increase initial_scale by nplus
             if step_size == init and target_iters > 0:
                 self.global_state['initial_scale'] *= nplus ** target_iters
-                self.global_state['initial_scale'] = min(self.global_state['initial_scale'], 1e32) # avoid overflow error
+
+                # clip by maximum possibel value to avoid overflow exception
+                self.global_state['initial_scale'] = min(
+                    self.global_state['initial_scale'],
+                    torch.finfo(var.params[0].dtype).max / 2,
+                )
 
         else:
             # otherwise make initial_scale such that target_iters iterations will satisfy armijo
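The success branch above grows `initial_scale` geometrically but now clips it against the parameter dtype instead of a hard-coded `1e32`; a small numeric sketch of that rule (hypothetical helper name):

```python
import torch

def grow_initial_scale(initial_scale, nplus=2.0, target_iters=1, dtype=torch.float32):
    """Sketch of the branch above: the first trial step was accepted, so grow the
    scale by nplus**target_iters, clipped so init * initial_scale cannot overflow."""
    initial_scale *= nplus ** target_iters
    return min(initial_scale, torch.finfo(dtype).max / 2)

scale = 1.0
for _ in range(5):
    scale = grow_initial_scale(scale)
print(scale)   # 32.0 after five consecutive first-try acceptances with nplus=2
```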
torchzero/modules/line_search/line_search.py

@@ -3,13 +3,13 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from functools import partial
 from operator import itemgetter
-from typing import Any
+from typing import Any, Literal
 
 import numpy as np
 import torch
 
 from ...core import Module, Target, Var
-from ...utils import tofloat
+from ...utils import tofloat, set_storage_
 
 
 class MaxLineSearchItersReached(Exception): pass
@@ -29,60 +29,59 @@ class LineSearchBase(Module, ABC):
             doesn't have a maxiter option. Defaults to None.
 
     Other useful methods:
-    * `evaluate_step_size` - returns loss with a given scalar step size
-    * `evaluate_step_size_loss_and_derivative` - returns loss and directional derivative with a given scalar step size
-    * `make_objective` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
-    * `make_objective_with_derivative` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
+    * ``evaluate_f`` - returns loss with a given scalar step size
+    * ``evaluate_f_d`` - returns loss and directional derivative with a given scalar step size
+    * ``make_objective`` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+    * ``make_objective_with_derivative`` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
 
     Examples:
-        #### Basic line search
 
-        This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+        #### Basic line search
 
-        .. code-block:: python
+        This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+        ```python
+        class GridLineSearch(LineSearch):
+            def __init__(self, start, end, num):
+                defaults = dict(start=start,end=end,num=num)
+                super().__init__(defaults)
 
-            class GridLineSearch(LineSearch):
-                def __init__(self, start, end, num):
-                    defaults = dict(start=start,end=end,num=num)
-                    super().__init__(defaults)
+            @torch.no_grad
+            def search(self, update, var):
 
-                @torch.no_grad
-                def search(self, update, var):
-                    settings = self.settings[var.params[0]]
-                    start = settings["start"]
-                    end = settings["end"]
-                    num = settings["num"]
+                start = self.defaults["start"]
+                end = self.defaults["end"]
+                num = self.defaults["num"]
 
-                    lowest_loss = float("inf")
-                    best_step_size = best_step_size
+                lowest_loss = float("inf")
+                best_step_size = best_step_size
 
-                    for step_size in torch.linspace(start,end,num):
-                        loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
-                        if loss < lowest_loss:
-                            lowest_loss = loss
-                            best_step_size = step_size
+                for step_size in torch.linspace(start,end,num):
+                    loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                    if loss < lowest_loss:
+                        lowest_loss = loss
+                        best_step_size = step_size
 
-                    return best_step_size
+                return best_step_size
+        ```
 
-        #### Using external solver via self.make_objective
+        #### Using external solver via self.make_objective
 
-        Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
+        Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
 
-        .. code-block:: python
+        ```python
+        class ScipyMinimizeScalar(LineSearch):
+            def __init__(self, method: str | None = None):
+                defaults = dict(method=method)
+                super().__init__(defaults)
 
-            class ScipyMinimizeScalar(LineSearch):
-                def __init__(self, method: str | None = None):
-                    defaults = dict(method=method)
-                    super().__init__(defaults)
-
-                @torch.no_grad
-                def search(self, update, var):
-                    objective = self.make_objective(var=var)
-                    method = self.settings[var.params[0]]["method"]
-
-                    res = self.scopt.minimize_scalar(objective, method=method)
-                    return res.x
+            @torch.no_grad
+            def search(self, update, var):
+                objective = self.make_objective(var=var)
+                method = self.defaults["method"]
 
+                res = self.scopt.minimize_scalar(objective, method=method)
+                return res.x
+        ```
     """
     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
         super().__init__(defaults)
@@ -94,6 +93,7 @@ class LineSearchBase(Module, ABC):
         self._lowest_loss = float('inf')
         self._best_step_size: float = 0
         self._current_iter = 0
+        self._initial_params = None
 
     def set_step_size_(
         self,
@@ -102,10 +102,27 @@ class LineSearchBase(Module, ABC):
         update: list[torch.Tensor],
     ):
         if not math.isfinite(step_size): return
-        step_size = max(min(tofloat(step_size), 1e36), -1e36) # fixes overflow when backtracking keeps increasing alpha after converging
-        alpha = self._current_step_size - step_size
-        if alpha != 0:
-            torch._foreach_add_(params, update, alpha=alpha)
+
+        # fixes overflow when backtracking keeps increasing alpha after converging
+        step_size = max(min(tofloat(step_size), 1e36), -1e36)
+
+        # skip is parameters are already at suggested step size
+        if self._current_step_size == step_size: return
+
+        # this was basically causing floating point imprecision to build up
+        #if False:
+        #    if abs(alpha) < abs(step_size) and step_size != 0:
+        #        torch._foreach_add_(params, update, alpha=alpha)
+
+        # else:
+        assert self._initial_params is not None
+        if step_size == 0:
+            new_params = [p.clone() for p in self._initial_params]
+        else:
+            new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
+
         self._current_step_size = step_size
 
     def _set_per_parameter_step_size_(
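The switch from incremental `_foreach_add_` adjustments to rebuilding parameters from `_initial_params` is about floating point drift; a self-contained comparison of the two strategies on toy tensors (names are illustrative, not the module's API):

```python
import torch

torch.manual_seed(0)
p0 = torch.randn(1000)                             # "initial" parameters
u = torch.randn(1000)                              # update direction
trials = [0.7 ** k for k in range(30)] + [0.0]     # backtracking-like schedule ending back at 0

# old strategy: nudge parameters by the difference between consecutive step sizes
p_inc, current = p0.clone(), 0.0
for s in trials:
    p_inc.add_(u, alpha=current - s)               # p += (current - s) * u
    current = s

# new strategy: always rebuild from the stored initial parameters
p_rebuilt = p0 - trials[-1] * u

print((p_inc - p0).abs().max().item())      # typically a small nonzero drift
print((p_rebuilt - p0).abs().max().item())  # exactly 0.0
```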
@@ -114,10 +131,20 @@ class LineSearchBase(Module, ABC):
         params: list[torch.Tensor],
         update: list[torch.Tensor],
     ):
-        if not np.isfinite(step_size): step_size = [0 for _ in step_size]
-        alpha = [self._current_step_size - s for s in step_size]
-        if any(a!=0 for a in alpha):
-            torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+        # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
+        # alpha = [self._current_step_size - s for s in step_size]
+        # if any(a!=0 for a in alpha):
+        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+        assert self._initial_params is not None
+        if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
+
+        if any(s!=0 for s in step_size):
+            new_params = torch._foreach_sub(self._initial_params, torch._foreach_mul(update, step_size))
+        else:
+            new_params = [p.clone() for p in self._initial_params]
+
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
 
     def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
               update: list[torch.Tensor], backward:bool=False) -> float:
@@ -149,7 +176,7 @@ class LineSearchBase(Module, ABC):
 
         return tofloat(loss)
 
-    def _loss_derivative(self, step_size: float, var: Var, closure,
+    def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
                          params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
         if (var.grad is not None) and (step_size == 0):
@@ -164,18 +191,31 @@ class LineSearchBase(Module, ABC):
         derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                            else torch.zeros_like(p) for p in params], update))
 
-        return loss, tofloat(derivative)
+        assert var.grad is not None
+        return loss, tofloat(derivative), var.grad
 
-    def evaluate_step_size(self, step_size: float, var: Var, backward:bool=False):
+    def _loss_derivative(self, step_size: float, var: Var, closure,
+                         params: list[torch.Tensor], update: list[torch.Tensor]):
+        return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
+
+    def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+        """evaluate function value at alpha `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
 
-    def evaluate_step_size_loss_and_derivative(self, step_size: float, var: Var):
+    def evaluate_f_d(self, step_size: float, var: Var):
+        """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
 
+    def evaluate_f_d_g(self, step_size: float, var: Var):
+        """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+
     def make_objective(self, var: Var, backward:bool=False):
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
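The minus sign in the directional derivative follows from the stepping convention above (`set_step_size_` evaluates parameters at `initial_params - step_size * update`), so d/da f(x0 - a*u) at a = 0 equals -⟨∇f(x0), u⟩. A quick standalone numerical check:

```python
import torch

x0 = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)
u = torch.tensor([0.3, 0.1, -0.7])                     # "update" direction

def f(x): return (x ** 4).sum()

f(x0).backward()
analytic = -(x0.grad * u).sum()                        # -<grad, update>, as in the code above

eps = 1e-4                                             # central difference in the step size a
with torch.no_grad():
    numeric = (f(x0 - eps * u) - f(x0 + eps * u)) / (2 * eps)

print(analytic.item(), numeric.item())                 # both ~2.35 for this toy problem
```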
@@ -186,6 +226,11 @@ class LineSearchBase(Module, ABC):
         if closure is None: raise RuntimeError('line search requires closure')
         return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
 
+    def make_objective_with_derivative_and_gradient(self, var: Var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+
     @abstractmethod
     def search(self, update: list[torch.Tensor], var: Var) -> float:
         """Finds the step size to use"""
@@ -193,7 +238,9 @@ class LineSearchBase(Module, ABC):
     @torch.no_grad
     def step(self, var: Var) -> Var:
         self._reset()
+
         params = var.params
+        self._initial_params = [p.clone() for p in params]
         update = var.get_update()
 
         try:
@@ -206,7 +253,6 @@ class LineSearchBase(Module, ABC):
 
         # this is last module - set step size to found step_size times lr
         if var.is_last:
-
             if var.last_module_lrs is None:
                 self.set_step_size_(step_size, params=params, update=update)
 
@@ -223,17 +269,62 @@ class LineSearchBase(Module, ABC):
 
 
 
-# class GridLineSearch(LineSearch):
-#     """Mostly for testing, this is not practical"""
-#     def __init__(self, start, end, num):
-#         defaults = dict(start=start,end=end,num=num)
-#         super().__init__(defaults)
-
-#     @torch.no_grad
-#     def search(self, update, var):
-#         start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
-
-#         for lr in torch.linspace(start,end,num):
-#             self.evaluate_step_size(lr.item(), var=var, backward=False)
-
-#         return self._best_step_size
+class GridLineSearch(LineSearchBase):
+    """"""
+    def __init__(self, start, end, num):
+        defaults = dict(start=start,end=end,num=num)
+        super().__init__(defaults)
 
+    @torch.no_grad
+    def search(self, update, var):
+        start,end,num=itemgetter('start','end','num')(self.defaults)
+
+        for lr in torch.linspace(start,end,num):
+            self.evaluate_f(lr.item(), var=var, backward=False)
+
+        return self._best_step_size
+
+
+def sufficient_decrease(f_0, g_0, f_a, a, c):
+    return f_a < f_0 + c*a*min(g_0, 0)
+
+def curvature(g_0, g_a, c):
+    if g_0 > 0: return True
+    return g_a >= c * g_0
+
+def strong_curvature(g_0, g_a, c):
+    """same as curvature condition except curvature can't be too positive (which indicates overstep)"""
+    if g_0 > 0: return True
+    return abs(g_a) <= c * abs(g_0)
+
+def wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and curvature(g_0, g_a, c2)
+
+def strong_wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and strong_curvature(g_0, g_a, c2)
+
+def goldstein(f_0, g_0, f_a, a, c):
+    """same as armijo (sufficient_decrease) but additional lower bound"""
+    g_0 = min(g_0, 0)
+    return f_0 + (1-c)*a*g_0 < f_a < f_0 + c*a*g_0
+
+TerminationCondition = Literal["armijo", "curvature", "strong_curvature", "wolfe", "strong_wolfe", "goldstein", "decrease"]
+def termination_condition(
+    condition: TerminationCondition,
+    f_0,
+    g_0,
+    f_a,
+    g_a: Any | None,
+    a,
+    c,
+    c2=None,
+):
+    if not math.isfinite(f_a): return False
+    if condition == 'armijo': return sufficient_decrease(f_0, g_0, f_a, a, c)
+    if condition == 'curvature': return curvature(g_0, g_a, c)
+    if condition == 'strong_curvature': return strong_curvature(g_0, g_a, c)
+    if condition == 'wolfe': return wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'strong_wolfe': return strong_wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'goldstein': return goldstein(f_0, g_0, f_a, a, c)
+    if condition == 'decrease': return f_a < f_0
+    raise ValueError(f"unknown condition {condition}")
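To see what the different predicates accept, here is a standalone check on the 1-D quadratic f(a) = (a - 1)^2 (f_0 = 1, g_0 = -2, exact minimizer at a = 1); the helpers are restated locally so the snippet runs on its own:

```python
def sufficient_decrease(f_0, g_0, f_a, a, c):
    return f_a < f_0 + c * a * min(g_0, 0)

def curvature(g_0, g_a, c):
    return True if g_0 > 0 else g_a >= c * g_0

def strong_curvature(g_0, g_a, c):
    return True if g_0 > 0 else abs(g_a) <= c * abs(g_0)

f = lambda a: (a - 1) ** 2          # value along the ray
g = lambda a: 2 * (a - 1)           # directional derivative along the ray
f_0, g_0, c1, c2 = f(0), g(0), 1e-4, 0.9

for a in (0.01, 0.5, 2.0):
    print(
        f"a={a}: armijo={sufficient_decrease(f_0, g_0, f(a), a, c1)}",
        f"curvature={curvature(g_0, g(a), c2)}",
        f"strong_curvature={strong_curvature(g_0, g(a), c2)}",
    )
# a=0.01: armijo passes but both curvature checks fail (step too short),
# a=0.5:  all three pass, so it satisfies strong Wolfe,
# a=2.0:  armijo and strong curvature fail (overstep past the minimizer).
```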
torchzero/modules/line_search/scipy.py

@@ -1,3 +1,4 @@
+import math
 from collections.abc import Mapping
 from operator import itemgetter
 
@@ -17,6 +18,7 @@ class ScipyMinimizeScalar(LineSearchBase):
         bounds (Sequence | None, optional):
             For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
         tol (float | None, optional): Tolerance for termination. Defaults to None.
+        prev_init (bool, optional): uses previous step size as initial guess for the line search.
         options (dict | None, optional): A dictionary of solver options. Defaults to None.
 
     For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
@@ -29,9 +31,10 @@ class ScipyMinimizeScalar(LineSearchBase):
         bracket=None,
         bounds=None,
         tol: float | None = None,
+        prev_init: bool = False,
         options=None,
     ):
-        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter)
+        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter, prev_init=prev_init)
         super().__init__(defaults)
 
         import scipy.optimize
@@ -42,11 +45,20 @@ class ScipyMinimizeScalar(LineSearchBase):
     def search(self, update, var):
         objective = self.make_objective(var=var)
         method, bracket, bounds, tol, options, maxiter = itemgetter(
-            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.defaults)
 
         if maxiter is not None:
             options = dict(options) if isinstance(options, Mapping) else {}
             options['maxiter'] = maxiter
 
-        res = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
-        return res.x
+        if self.defaults["prev_init"] and "x_prev" in self.global_state:
+            if bracket is None: bracket = (0, 1)
+            bracket = (*bracket[:-1], self.global_state["x_prev"])
+
+        x = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options).x # pyright:ignore[reportAttributeAccessIssue]
+
+        max = torch.finfo(var.params[0].dtype).max / 2
+        if (not math.isfinite(x)) or abs(x) >= max: x = 0
+
+        self.global_state['x_prev'] = x
+        return x
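A sketch of how the new `prev_init` option seeds the solver: the last bracket point is replaced with the step size found on the previous call (a plain-Python restatement of the bracket handling above, with a hypothetical helper name; the module then passes the bracket to `scipy.optimize.minimize_scalar`):

```python
def seeded_bracket(bracket, x_prev):
    """Replace the last bracket point with the previously accepted step size."""
    if bracket is None:
        bracket = (0, 1)
    return (*bracket[:-1], x_prev)

print(seeded_bracket(None, 0.3))       # (0, 0.3)
print(seeded_bracket((0, 1, 2), 0.3))  # (0, 1, 0.3)
```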