torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
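Most of the churn in the file list is a package reorganization: the `optimizers` subpackage becomes `adaptive` (the `{optimizers → adaptive}` entries above), `cg.py` moves from `quasi_newton` to a new `conjugate_gradient` subpackage, and a number of experimental modules are removed. The sketch below is a hypothetical compatibility guard for code that imported these files by their deep path; only the file moves are visible in this diff, so the class name `Adam` and whether such deep imports are supported API are assumptions.

```python
# Hypothetical guard for the optimizers -> adaptive rename listed above.
# The class name `Adam` and the stability of these deep import paths are assumptions.
try:
    from torchzero.modules.adaptive.adam import Adam      # 0.3.13 layout
except ImportError:
    from torchzero.modules.optimizers.adam import Adam    # 0.3.11 layout
```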
torchzero/modules/line_search/strong_wolfe.py
@@ -1,276 +1,377 @@
-"""this needs to be reworked maybe but it also works"""
 import math
 import warnings
 from operator import itemgetter
+from typing import Literal

+import numpy as np
 import torch
 from torch.optim.lbfgs import _cubic_interpolate

-from .line_search import LineSearchBase
-from ...utils import totensor
+from ...utils import as_tensorlist, totensor
+from ._polyinterp import polyinterp, polyinterp2
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
+from ..step_size.adaptive import _bb_geom
+
+def _totensor(x):
+    if not isinstance(x, torch.Tensor): return torch.tensor(x, dtype=torch.float32)
+    return x
+
+def _within_bounds(x, bounds):
+    if bounds is None: return True
+    lb,ub = bounds
+    if lb is not None and x < lb: return False
+    if ub is not None and x > ub: return False
+    return True
+
+def _apply_bounds(x, bounds):
+    if bounds is None: return True
+    lb,ub = bounds
+    if lb is not None and x < lb: return lb
+    if ub is not None and x > ub: return ub
+    return x
+
+class _StrongWolfe:
+    def __init__(
+        self,
+        f,
+        f_0,
+        g_0,
+        d_norm,
+        a_init,
+        a_max,
+        c1,
+        c2,
+        maxiter,
+        maxeval,
+        maxzoom,
+        tol_change,
+        interpolation: Literal["quadratic", "cubic", "bisection", "polynomial", "polynomial2"],
+    ):
+        self._f = f
+        self.f_0 = f_0
+        self.g_0 = g_0
+        self.d_norm = d_norm
+        self.a_init = a_init
+        self.a_max = a_max
+        self.c1 = c1
+        self.c2 = c2
+        self.maxiter = maxiter
+        if maxeval is None: maxeval = float('inf')
+        self.maxeval = maxeval
+        self.tol_change = tol_change
+        self.num_evals = 0
+        self.maxzoom = maxzoom
+        self.interpolation = interpolation
+
+        self.history = {}
+
+    def f(self, a):
+        if a in self.history: return self.history[a]
+        self.num_evals += 1
+        f_a, g_a = self._f(a)
+        self.history[a] = (f_a, g_a)
+        return f_a, g_a
+
+    def interpolate(self, a_lo, f_lo, g_lo, a_hi, f_hi, g_hi, bounds=None):
+        if self.interpolation == 'cubic':
+            # pytorch cubic interpolate needs tensors
+            a_lo = _totensor(a_lo); f_lo = _totensor(f_lo); g_lo = _totensor(g_lo)
+            a_hi = _totensor(a_hi); f_hi = _totensor(f_hi); g_hi = _totensor(g_hi)
+            return float(_cubic_interpolate(x1=a_lo, f1=f_lo, g1=g_lo, x2=a_hi, f2=f_hi, g2=g_hi, bounds=bounds))
+
+        if self.interpolation == 'bisection':
+            return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+
+        if self.interpolation == 'quadratic':
+            a = a_hi - a_lo
+            denom = 2 * (f_hi - f_lo - g_lo*a)
+            if denom > 1e-32:
+                num = g_lo * a**2
+                a_min = num / -denom
+                return _apply_bounds(a_min, bounds)
+            return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+
+        if self.interpolation in ('polynomial', 'polynomial2'):
+            finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+            if bounds is None: bounds = (None, None)
+            polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
+            try:
+                return _apply_bounds(polyinterp_fn(np.array(finite_history), *bounds), bounds) # pyright:ignore[reportArgumentType]
+            except torch.linalg.LinAlgError:
+                return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+        else:
+            raise ValueError(self.interpolation)
+
+    def zoom(self, a_lo, f_lo, g_lo, a_hi, f_hi, g_hi):
+        if a_lo >= a_hi:
+            a_hi, f_hi, g_hi, a_lo, f_lo, g_lo = a_lo, f_lo, g_lo, a_hi, f_hi, g_hi
+
+        insuf_progress = False
+        for _ in range(self.maxzoom):
+            if self.num_evals >= self.maxeval: break
+            if (a_hi - a_lo) * self.d_norm < self.tol_change: break # small bracket
+
+            if not (math.isfinite(f_hi) and math.isfinite(g_hi)):
+                a_hi = a_hi / 2
+                f_hi, g_hi = self.f(a_hi)
+                continue
+
+            a_j = self.interpolate(a_lo, f_lo, g_lo, a_hi, f_hi, g_hi, bounds=(a_lo, min(a_hi, self.a_max)))
+
+            # this part is from https://github.com/pytorch/pytorch/blob/main/torch/optim/lbfgs.py:
+            eps = 0.1 * (a_hi - a_lo)
+            if min(a_hi - a_j, a_j - a_lo) < eps:
+                # interpolation close to boundary
+                if insuf_progress or a_j >= a_hi or a_j <= a_lo:
+                    # evaluate at 0.1 away from boundary
+                    if abs(a_j - a_hi) < abs(a_j - a_lo):
+                        a_j = a_hi - eps
+                    else:
+                        a_j = a_lo + eps
+                    insuf_progress = False
+                else:
+                    insuf_progress = True
+            else:
+                insuf_progress = False

+            f_j, g_j = self.f(a_j)

-def _zoom(f,
-          a_l, a_h,
-          f_l, g_l,
-          f_h, g_h,
-          f_0, g_0,
-          c1, c2,
-          maxzoom):
+            if f_j > self.f_0 + self.c1*a_j*self.g_0 or f_j > f_lo:
+                a_hi, f_hi, g_hi = a_j, f_j, g_j

-    for i in range(maxzoom):
-        a_j = _cubic_interpolate(
-            *(totensor(i) for i in (a_l, f_l, g_l, a_h, f_h, g_h))
+            else:
+                if abs(g_j) <= -self.c2 * self.g_0:
+                    return a_j, f_j, g_j

-        )
+                if g_j * (a_hi - a_lo) >= 0:
+                    a_hi, f_hi, g_hi = a_lo, f_lo, g_lo

-        # if interpolation fails or produces endpoint, bisect
-        delta = abs(a_h - a_l)
-        if a_j is None or a_j == a_l or a_j == a_h:
-            a_j = a_l + 0.5 * delta
+                a_lo, f_lo, g_lo = a_j, f_j, g_j

+        # fail
+        return None, None, None

-        f_j, g_j = f(a_j)
+    def search(self):
+        a_i = min(self.a_init, self.a_max)
+        f_i = g_i = None
+        a_prev = 0
+        f_prev = self.f_0
+        g_prev = self.g_0
+        for i in range(self.maxiter):
+            if self.num_evals >= self.maxeval: break
+            f_i, g_i = self.f(a_i)

-        # check armijo
-        armijo = f_j <= f_0 + c1 * a_j * g_0
+            if f_i > self.f_0 + self.c1*a_i*self.g_0 or (i > 0 and f_i > f_prev):
+                return self.zoom(a_prev, f_prev, g_prev, a_i, f_i, g_i)

-        # check strong wolfe
-        wolfe = abs(g_j) <= c2 * abs(g_0)
+            if abs(g_i) <= -self.c2 * self.g_0:
+                return a_i, f_i, g_i

+            if g_i >= 0:
+                return self.zoom(a_i, f_i, g_i, a_prev, f_prev, g_prev)
+
+            # from pytorch
+            min_step = a_i + 0.01 * (a_i - a_prev)
+            max_step = a_i * 10
+            a_i_next = self.interpolate(a_prev, f_prev, g_prev, a_i, f_i, g_i, bounds=(min_step, min(max_step, self.a_max)))
+            # a_i_next = self.interpolate(a_prev, f_prev, g_prev, a_i, f_i, g_i, bounds=(0, self.a_max))
+
+            a_prev, f_prev, g_prev = a_i, f_i, g_i
+            a_i = a_i_next
+
+        if self.num_evals < self.maxeval:
+            assert f_i is not None and g_i is not None
+            return self.zoom(0, self.f_0, self.g_0, a_i, f_i, g_i)
+
+        return None, None, None

-        # minimum between alpha_low and alpha_j
-        if not armijo or f_j >= f_l:
-            a_h = a_j
-            f_h = f_j
-            g_h = g_j
-        else:
-            # alpha_j satisfies armijo
-            if wolfe:
-                return a_j, f_j
-
-            # minimum between alpha_j and alpha_high
-            if g_j * (a_h - a_l) >= 0:
-                # between alpha_low and alpha_j
-                # a_h = a_l
-                # f_h = f_l
-                # g_h = g_l
-                a_h = a_j
-                f_h = f_j
-                g_h = g_j
-
-            # is this messing it up?
-            else:
-                a_l = a_j
-                f_l = f_j
-                g_l = g_j
-
-
-
-
-        # check if interval too small
-        delta = abs(a_h - a_l)
-        if delta <= 1e-9 or delta <= 1e-6 * max(abs(a_l), abs(a_h)):
-            l_satisfies_wolfe = (f_l <= f_0 + c1 * a_l * g_0) and (abs(g_l) <= c2 * abs(g_0))
-            h_satisfies_wolfe = (f_h <= f_0 + c1 * a_h * g_0) and (abs(g_h) <= c2 * abs(g_0))
-
-            if l_satisfies_wolfe and h_satisfies_wolfe: return a_l if f_l <= f_h else a_h, f_h
-            if l_satisfies_wolfe: return a_l, f_l
-            if h_satisfies_wolfe: return a_h, f_h
-            if f_l <= f_0 + c1 * a_l * g_0: return a_l, f_l
-            return None,None
-
-        if a_j is None or a_j == a_l or a_j == a_h:
-            a_j = a_l + 0.5 * delta
-
-
-    return None,None
-
-
-def strong_wolfe(
-    f,
-    f_0,
-    g_0,
-    init: float = 1.0,
-    c1: float = 1e-4,
-    c2: float = 0.9,
-    maxiter: int = 25,
-    maxzoom: int = 15,
-    # a_max: float = 1e30,
-    expand: float = 2.0, # Factor to increase alpha in bracketing
-    plus_minus: bool = False,
-) -> tuple[float,float] | tuple[None,None]:
-    a_prev = 0.0
-
-    if g_0 == 0: return None,None
-    if g_0 > 0:
-        # if direction is not a descent direction, perform line search in opposite direction
-        if plus_minus:
-            def inverted_objective(alpha):
-                l, g = f(-alpha)
-                return l, -g
-            a, v = strong_wolfe(
-                inverted_objective,
-                init=init,
-                f_0=f_0,
-                g_0=-g_0,
-                c1=c1,
-                c2=c2,
-                maxiter=maxiter,
-                # a_max=a_max,
-                expand=expand,
-                plus_minus=False,
-            )
-            if a is not None and v is not None: return -a, v
-        return None, None
-
-    f_prev = f_0
-    g_prev = g_0
-    a_cur = init
-
-    # bracket
-    for i in range(maxiter):
-
-        f_cur, g_cur = f(a_cur)
-
-        # check armijo
-        armijo_violated = f_cur > f_0 + c1 * a_cur * g_0
-        func_increased = f_cur >= f_prev and i > 0
-
-        if armijo_violated or func_increased:
-            return _zoom(f,
-                a_prev, a_cur,
-                f_prev, g_prev,
-                f_cur, g_cur,
-                f_0, g_0,
-                c1, c2,
-                maxzoom=maxzoom,
-            )
-
-
-
-        # check strong wolfe
-        if abs(g_cur) <= c2 * abs(g_0):
-            return a_cur, f_cur
-
-        # minimum is bracketed
-        if g_cur >= 0:
-            return _zoom(f,
-                #alpha_curr, alpha_prev,
-                a_prev, a_cur,
-                #phi_curr, phi_prime_curr,
-                f_prev, g_prev,
-                f_cur, g_cur,
-                f_0, g_0,
-                c1, c2,
-                maxzoom=maxzoom,)
-
-        # otherwise continue bracketing
-        a_next = a_cur * expand
-
-        # update previous point and continue loop with increased step size
-        a_prev = a_cur
-        f_prev = f_cur
-        g_prev = g_cur
-        a_cur = a_next
-
-
-    # max iters reached
-    return None, None
-
-def _notfinite(x):
-    if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
-    return not math.isfinite(x)

 class StrongWolfe(LineSearchBase):
-    """Cubic interpolation line search satisfying Strong Wolfe condition.
+    """Interpolation line search satisfying Strong Wolfe condition.

     Args:
-        init (float, optional): Initial step size. Defaults to 1.0.
-        c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
-        c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
-        maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
-        maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
-        expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
-        use_prev (bool, optional):
-            if True, previous step size is used as the initial step size on the next step.
+        c1 (float, optional): sufficient descent condition. Defaults to 1e-4.
+        c2 (float, optional): strong curvature condition. For CG set to 0.1. Defaults to 0.9.
+        a_init (str, optional):
+            strategy for initializing the initial step size guess.
+            - "fixed" - uses a fixed value specified in `init_value` argument.
+            - "first-order" - assumes first-order change in the function at iterate will be the same as that obtained at the previous step.
+            - "quadratic" - interpolates quadratic to f(x_{-1}) and f_x.
+            - "quadratic-clip" - same as quad, but uses min(1, 1.01*alpha) as described in Numerical Optimization.
+            - "previous" - uses final step size found on previous iteration.
+
+            For 2nd order methods it is usually best to leave at "fixed".
+            For methods that do not produce well scaled search directions, e.g. conjugate gradient,
+            "first-order" or "quadratic-clip" are recommended. Defaults to 'init'.
+        a_max (float, optional): upper bound for the proposed step sizes. Defaults to 1e12.
+        init_value (float, optional):
+            initial step size. Used when ``a_init``="fixed", and with other strategies as fallback value. Defaults to 1.
+        maxiter (int, optional): maximum number of line search iterations. Defaults to 25.
+        maxzoom (int, optional): maximum number of zoom iterations. Defaults to 10.
+        maxeval (int | None, optional): maximum number of function evaluations. Defaults to None.
+        tol_change (float, optional): tolerance, terminates on small brackets. Defaults to 1e-9.
+        interpolation (str, optional):
+            What type of interpolation to use.
+            - "bisection" - uses the middle point. This is robust, especially if the objective function is non-smooth, however it may need more function evaluations.
+            - "quadratic" - minimizes a quadratic model, generally outperformed by "cubic".
+            - "cubic" - minimizes a cubic model - this is the most widely used interpolation strategy.
+            - "polynomial" - fits a a polynomial to all points obtained during line search.
+            - "polynomial2" - alternative polynomial fit, where if a point is outside of bounds, a lower degree polynomial is tried.
+              This may have faster convergence than "cubic" and "polynomial".
+
+            Defaults to 'cubic'.
         adaptive (bool, optional):
-            when enabled, if line search failed, initial step size is reduced.
-            Otherwise it is reset to initial value. Defaults to True.
+            if True, the initial step size will be halved when line search failed to find a good direction.
+            When a good direction is found, initial step size is reset to the original value. Defaults to True.
+        fallback (bool, optional):
+            if True, when no point satisfied strong wolfe criteria,
+            returns a point with value lower than initial value that doesn't satisfy the criteria. Defaults to False.
         plus_minus (bool, optional):
-            If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+            if True, enables the plus-minus variant, where if curvature is negative, line search is performed
+            in the opposite direction. Defaults to False.


-    Examples:
-        Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
+    ## Examples:

-        .. code-block:: python
+    Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well scaled directions, initial alpha can be determined from function values by ``a_init="first-order"``.

-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.PolakRibiere(),
-                tz.m.StrongWolfe(c2=0.1)
-            )
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.PolakRibiere(),
+        tz.m.StrongWolfe(c2=0.1, a_init="first-order")
+    )
+    ```

-        LBFGS strong wolfe line search:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LBFGS(),
-                tz.m.StrongWolfe()
-            )
+    LBFGS strong wolfe line search:
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LBFGS(),
+        tz.m.StrongWolfe()
+    )
+    ```

     """
     def __init__(
         self,
-        init: float = 1.0,
         c1: float = 1e-4,
         c2: float = 0.9,
+        a_init: Literal['first-order', 'quadratic', 'quadratic-clip', 'previous', 'fixed'] = 'fixed',
+        a_max: float = 1e12,
+        init_value: float = 1,
         maxiter: int = 25,
         maxzoom: int = 10,
-        # a_max: float = 1e10,
-        expand: float = 2.0,
-        use_prev: bool = False,
+        maxeval: int | None = None,
+        tol_change: float = 1e-9,
+        interpolation: Literal["quadratic", "cubic", "bisection", "polynomial", 'polynomial2'] = 'cubic',
         adaptive = True,
+        fallback:bool = False,
         plus_minus = False,
     ):
-        defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
-                      expand=expand, adaptive=adaptive, plus_minus=plus_minus,use_prev=use_prev)
+        defaults=dict(init_value=init_value,init=a_init,a_max=a_max,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom, fallback=fallback,
+                      maxeval=maxeval, adaptive=adaptive, interpolation=interpolation, plus_minus=plus_minus, tol_change=tol_change)
         super().__init__(defaults=defaults)

         self.global_state['initial_scale'] = 1.0
-        self.global_state['beta_scale'] = 1.0

     @torch.no_grad
     def search(self, update, var):
+        self._g_prev = self._f_prev = None
         objective = self.make_objective_with_derivative(var=var)

-        init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus, use_prev = itemgetter(
-            'init', 'c1', 'c2', 'maxiter', 'maxzoom',
-            'expand', 'adaptive', 'plus_minus', 'use_prev')(self.settings[var.params[0]])
+        init_value, init, c1, c2, a_max, maxiter, maxzoom, maxeval, interpolation, adaptive, plus_minus, fallback, tol_change = itemgetter(
+            'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
+            'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)
+
+        dir = as_tensorlist(var.get_update())
+        grad_list = var.get_grad()
+
+        g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
+        f_0 = var.get_loss(False)
+        dir_norm = dir.global_vector_norm()
+
+        inverted = False
+        if plus_minus and g_0 > 0:
+            original_objective = objective
+            def inverted_objective(a):
+                l, g_a = original_objective(-a)
+                return l, -g_a
+            objective = inverted_objective
+            inverted = True
+
+        # --------------------- determine initial step size guess -------------------- #
+        init = init.lower().strip()
+
+        a_init = init_value
+        if init == 'fixed':
+            pass # use init_value
+
+        elif init == 'previous':
+            if 'a_prev' in self.global_state:
+                a_init = self.global_state['a_prev']
+
+        elif init == 'first-order':
+            if 'g_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
+                a_prev = self.global_state['a_prev']
+                g_prev = self.global_state['g_prev']
+                if g_prev < 0:
+                    a_init = a_prev * g_prev / g_0
+
+        elif init in ('quadratic', 'quadratic-clip'):
+            if 'f_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
+                f_prev = self.global_state['f_prev']
+                if f_0 < f_prev:
+                    a_init = 2 * (f_0 - f_prev) / g_0
+                    if init == 'quadratic-clip': a_init = min(1, 1.01*a_init)
+        else:
+            raise ValueError(init)
+
+        if adaptive:
+            a_init *= self.global_state.get('initial_scale', 1)

-        f_0, g_0 = objective(0)
-        if use_prev: init = self.global_state.get('prev_alpha', init)

-        step_size,f_a = strong_wolfe(
-            objective,
-            f_0=f_0, g_0=g_0,
-            init=init * self.global_state.setdefault("initial_scale", 1),
+        strong_wolfe = _StrongWolfe(
+            f=objective,
+            f_0=f_0,
+            g_0=g_0,
+            d_norm=dir_norm,
+            a_init=a_init,
+            a_max=a_max,
             c1=c1,
             c2=c2,
             maxiter=maxiter,
             maxzoom=maxzoom,
-            expand=expand,
-            plus_minus=plus_minus,
+            maxeval=maxeval,
+            tol_change=tol_change,
+            interpolation=interpolation,
         )

-        if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
-        if step_size is not None and step_size != 0 and not _notfinite(step_size):
-            self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
-            self.global_state['prev_alpha'] = step_size
-            return step_size
+        a, f_a, g_a = strong_wolfe.search()
+        if inverted and a is not None: a = -a
+        if f_a is not None and (f_a > f_0 or not math.isfinite(f_a)): a = None
+
+        if fallback:
+            if a is None or a==0 or not math.isfinite(a):
+                lowest = min(strong_wolfe.history.items(), key=lambda x: x[1][0])
+                if lowest[1][0] < f_0:
+                    a = lowest[0]
+                    f_a, g_a = lowest[1]
+                    if inverted: a = -a
+
+        if a is not None and a != 0 and math.isfinite(a):
+            #self.global_state['initial_scale'] = min(1.0, self.global_state.get('initial_scale', 1) * math.sqrt(2))
+            self.global_state['initial_scale'] = 1
+            self.global_state['a_prev'] = a
+            self.global_state['f_prev'] = f_0
+            self.global_state['g_prev'] = g_0
+            return a
+
+        # fail
+        if adaptive:
+            self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
+            finfo = torch.finfo(dir[0].dtype)
+            if self.global_state['initial_scale'] < finfo.tiny * 2:
+                self.global_state['initial_scale'] = finfo.max / 2

-        if adaptive: self.global_state['initial_scale'] *= 0.5
         return 0
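The reworked search accepts a step size `a` only when both the sufficient-decrease condition `f(a) <= f(0) + c1*a*g(0)` and the strong curvature condition `|g(a)| <= c2*|g(0)|` hold (the checks in `_StrongWolfe.search` and `zoom` above). A minimal usage sketch of the new 0.3.13 options follows; the argument names are taken from the `__init__` signature in the hunk, the `tz.Modular` / `tz.m` API from the docstring examples, and the toy model is illustrative only.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)  # placeholder model, illustrative only

# CG benefits from c2=0.1 and an initial step size derived from function values.
cg = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),
    tz.m.StrongWolfe(c2=0.1, a_init="first-order", interpolation="polynomial"),
)

# Second-order directions are usually well scaled, so the fixed initial step is kept;
# fallback=True accepts the lowest evaluated point when no step satisfies Wolfe.
lbfgs = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.StrongWolfe(a_init="fixed", fallback=True),
)
```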
torchzero/modules/misc/__init__.py
@@ -1,6 +1,13 @@
 from .debug import PrintLoss, PrintParams, PrintShape, PrintUpdate
 from .escape import EscapeAnnealing
 from .gradient_accumulation import GradientAccumulation
+from .homotopy import (
+    ExpHomotopy,
+    LambdaHomotopy,
+    LogHomotopy,
+    SqrtHomotopy,
+    SquareHomotopy,
+)
 from .misc import (
     DivByLoss,
     FillLoss,
@@ -20,6 +27,7 @@ from .misc import (
     RandomHvp,
     Relative,
     UpdateSign,
+    SaveBest,
 )
 from .multistep import Multistep, NegateOnLossIncrease, Online, Sequential
 from .regularization import Dropout, PerturbWeights, WeightDropout
torchzero/modules/misc/debug.py
@@ -12,7 +12,7 @@ class PrintUpdate(Module):
         super().__init__(defaults)

     def step(self, var):
-        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.update}')
+        self.defaults["print_fn"](f'{self.defaults["text"]}{var.update}')
         return var

 class PrintShape(Module):
@@ -23,7 +23,7 @@ class PrintShape(Module):

     def step(self, var):
         shapes = [u.shape for u in var.update] if var.update is not None else None
-        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
+        self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
         return var

 class PrintParams(Module):
@@ -33,7 +33,7 @@ class PrintParams(Module):
         super().__init__(defaults)

     def step(self, var):
-        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.params}')
+        self.defaults["print_fn"](f'{self.defaults["text"]}{var.params}')
         return var


@@ -44,5 +44,5 @@ class PrintLoss(Module):
         super().__init__(defaults)

     def step(self, var):
-        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.get_loss(False)}')
+        self.defaults["print_fn"](f'{self.defaults["text"]}{var.get_loss(False)}')
         return var
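Every debug.py hunk makes the same change: module-wide options are now read from `self.defaults` instead of the per-parameter `self.settings[var.params[0]]` dict. A hypothetical custom module following the new convention is sketched below. The `Module` base class, the `defaults` constructor argument, and `var.get_loss(False)` are taken from the hunks above; the import path is inferred from the file list (`torchzero/core/module.py`), and the class itself is illustrative, not part of the released package.

```python
# Hypothetical module following the 0.3.13 convention shown above:
# global (non per-parameter) settings are read from self.defaults.
from torchzero.core.module import Module  # import path assumed from the file list

class PrintLossDelta(Module):
    """Prints the change in loss between consecutive steps (illustrative only)."""
    def __init__(self, text="loss delta: ", print_fn=print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        loss = float(var.get_loss(False))
        prev = getattr(self, "_prev_loss", loss)
        self.defaults["print_fn"](f'{self.defaults["text"]}{loss - prev}')
        self._prev_loss = loss
        return var
```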