torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/strong_wolfe.py
@@ -1,249 +1,377 @@
  import math
  import warnings
  from operator import itemgetter
+ from typing import Literal

+ import numpy as np
  import torch
  from torch.optim.lbfgs import _cubic_interpolate

- from .line_search import LineSearch
- from .backtracking import backtracking_line_search
- from ...utils import totensor
+ from ...utils import as_tensorlist, totensor
+ from ._polyinterp import polyinterp, polyinterp2
+ from .line_search import LineSearchBase, TerminationCondition, termination_condition
+ from ..step_size.adaptive import _bb_geom
+
+ def _totensor(x):
+     if not isinstance(x, torch.Tensor): return torch.tensor(x, dtype=torch.float32)
+     return x
+
+ def _within_bounds(x, bounds):
+     if bounds is None: return True
+     lb,ub = bounds
+     if lb is not None and x < lb: return False
+     if ub is not None and x > ub: return False
+     return True
+
+ def _apply_bounds(x, bounds):
+     if bounds is None: return True
+     lb,ub = bounds
+     if lb is not None and x < lb: return lb
+     if ub is not None and x > ub: return ub
+     return x
+
+ class _StrongWolfe:
+     def __init__(
+         self,
+         f,
+         f_0,
+         g_0,
+         d_norm,
+         a_init,
+         a_max,
+         c1,
+         c2,
+         maxiter,
+         maxeval,
+         maxzoom,
+         tol_change,
+         interpolation: Literal["quadratic", "cubic", "bisection", "polynomial", "polynomial2"],
+     ):
+         self._f = f
+         self.f_0 = f_0
+         self.g_0 = g_0
+         self.d_norm = d_norm
+         self.a_init = a_init
+         self.a_max = a_max
+         self.c1 = c1
+         self.c2 = c2
+         self.maxiter = maxiter
+         if maxeval is None: maxeval = float('inf')
+         self.maxeval = maxeval
+         self.tol_change = tol_change
+         self.num_evals = 0
+         self.maxzoom = maxzoom
+         self.interpolation = interpolation
+
+         self.history = {}
+
+     def f(self, a):
+         if a in self.history: return self.history[a]
+         self.num_evals += 1
+         f_a, g_a = self._f(a)
+         self.history[a] = (f_a, g_a)
+         return f_a, g_a
+
+     def interpolate(self, a_lo, f_lo, g_lo, a_hi, f_hi, g_hi, bounds=None):
+         if self.interpolation == 'cubic':
+             # pytorch cubic interpolate needs tensors
+             a_lo = _totensor(a_lo); f_lo = _totensor(f_lo); g_lo = _totensor(g_lo)
+             a_hi = _totensor(a_hi); f_hi = _totensor(f_hi); g_hi = _totensor(g_hi)
+             return float(_cubic_interpolate(x1=a_lo, f1=f_lo, g1=g_lo, x2=a_hi, f2=f_hi, g2=g_hi, bounds=bounds))
+
+         if self.interpolation == 'bisection':
+             return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+
+         if self.interpolation == 'quadratic':
+             a = a_hi - a_lo
+             denom = 2 * (f_hi - f_lo - g_lo*a)
+             if denom > 1e-32:
+                 num = g_lo * a**2
+                 a_min = num / -denom
+                 return _apply_bounds(a_min, bounds)
+             return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+
+         if self.interpolation in ('polynomial', 'polynomial2'):
+             finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+             if bounds is None: bounds = (None, None)
+             polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
+             try:
+                 return _apply_bounds(polyinterp_fn(np.array(finite_history), *bounds), bounds) # pyright:ignore[reportArgumentType]
+             except torch.linalg.LinAlgError:
+                 return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+         else:
+             raise ValueError(self.interpolation)
+
+     def zoom(self, a_lo, f_lo, g_lo, a_hi, f_hi, g_hi):
+         if a_lo >= a_hi:
+             a_hi, f_hi, g_hi, a_lo, f_lo, g_lo = a_lo, f_lo, g_lo, a_hi, f_hi, g_hi
+
+         insuf_progress = False
+         for _ in range(self.maxzoom):
+             if self.num_evals >= self.maxeval: break
+             if (a_hi - a_lo) * self.d_norm < self.tol_change: break # small bracket
+
+             if not (math.isfinite(f_hi) and math.isfinite(g_hi)):
+                 a_hi = a_hi / 2
+                 f_hi, g_hi = self.f(a_hi)
+                 continue
+
+             a_j = self.interpolate(a_lo, f_lo, g_lo, a_hi, f_hi, g_hi, bounds=(a_lo, min(a_hi, self.a_max)))
+
+             # this part is from https://github.com/pytorch/pytorch/blob/main/torch/optim/lbfgs.py:
+             eps = 0.1 * (a_hi - a_lo)
+             if min(a_hi - a_j, a_j - a_lo) < eps:
+                 # interpolation close to boundary
+                 if insuf_progress or a_j >= a_hi or a_j <= a_lo:
+                     # evaluate at 0.1 away from boundary
+                     if abs(a_j - a_hi) < abs(a_j - a_lo):
+                         a_j = a_hi - eps
+                     else:
+                         a_j = a_lo + eps
+                     insuf_progress = False
+                 else:
+                     insuf_progress = True
+             else:
+                 insuf_progress = False

+             f_j, g_j = self.f(a_j)

- def _zoom(f,
-           a_l, a_h,
-           f_l, g_l,
-           f_h, g_h,
-           f_0, g_0,
-           c1, c2,
-           maxzoom):
+             if f_j > self.f_0 + self.c1*a_j*self.g_0 or f_j > f_lo:
+                 a_hi, f_hi, g_hi = a_j, f_j, g_j

-     for i in range(maxzoom):
-         a_j = _cubic_interpolate(
-             *(totensor(i) for i in (a_l, f_l, g_l, a_h, f_h, g_h))
+             else:
+                 if abs(g_j) <= -self.c2 * self.g_0:
+                     return a_j, f_j, g_j

-         )
+                 if g_j * (a_hi - a_lo) >= 0:
+                     a_hi, f_hi, g_hi = a_lo, f_lo, g_lo

-         # if interpolation fails or produces endpoint, bisect
-         delta = abs(a_h - a_l)
-         if a_j is None or a_j == a_l or a_j == a_h:
-             a_j = a_l + 0.5 * delta
+                 a_lo, f_lo, g_lo = a_j, f_j, g_j

+         # fail
+         return None, None, None

-         f_j, g_j = f(a_j)
+     def search(self):
+         a_i = min(self.a_init, self.a_max)
+         f_i = g_i = None
+         a_prev = 0
+         f_prev = self.f_0
+         g_prev = self.g_0
+         for i in range(self.maxiter):
+             if self.num_evals >= self.maxeval: break
+             f_i, g_i = self.f(a_i)

-         # check armijo
-         armijo = f_j <= f_0 + c1 * a_j * g_0
+             if f_i > self.f_0 + self.c1*a_i*self.g_0 or (i > 0 and f_i > f_prev):
+                 return self.zoom(a_prev, f_prev, g_prev, a_i, f_i, g_i)

-         # check strong wolfe
-         wolfe = abs(g_j) <= c2 * abs(g_0)
+             if abs(g_i) <= -self.c2 * self.g_0:
+                 return a_i, f_i, g_i

+             if g_i >= 0:
+                 return self.zoom(a_i, f_i, g_i, a_prev, f_prev, g_prev)

-         # minimum between alpha_low and alpha_j
-         if not armijo or f_j >= f_l:
-             a_h = a_j
-             f_h = f_j
-             g_h = g_j
-         else:
-             # alpha_j satisfies armijo
-             if wolfe:
-                 return a_j, f_j
-
-             # minimum between alpha_j and alpha_high
-             if g_j * (a_h - a_l) >= 0:
-                 # between alpha_low and alpha_j
-                 # a_h = a_l
-                 # f_h = f_l
-                 # g_h = g_l
-                 a_h = a_j
-                 f_h = f_j
-                 g_h = g_j
-
-                 # is this messing it up?
-             else:
-                 a_l = a_j
-                 f_l = f_j
-                 g_l = g_j
-
-
-
-
-         # check if interval too small
-         delta = abs(a_h - a_l)
-         if delta <= 1e-9 or delta <= 1e-6 * max(abs(a_l), abs(a_h)):
-             l_satisfies_wolfe = (f_l <= f_0 + c1 * a_l * g_0) and (abs(g_l) <= c2 * abs(g_0))
-             h_satisfies_wolfe = (f_h <= f_0 + c1 * a_h * g_0) and (abs(g_h) <= c2 * abs(g_0))
-
-             if l_satisfies_wolfe and h_satisfies_wolfe: return a_l if f_l <= f_h else a_h, f_h
-             if l_satisfies_wolfe: return a_l, f_l
-             if h_satisfies_wolfe: return a_h, f_h
-             if f_l <= f_0 + c1 * a_l * g_0: return a_l, f_l
-             return None,None
-
-         if a_j is None or a_j == a_l or a_j == a_h:
-             a_j = a_l + 0.5 * delta
-
-
-     return None,None
-
-
- def strong_wolfe(
-     f,
-     f_0,
-     g_0,
-     init: float = 1.0,
-     c1: float = 1e-4,
-     c2: float = 0.9,
-     maxiter: int = 25,
-     maxzoom: int = 15,
-     # a_max: float = 1e30,
-     expand: float = 2.0, # Factor to increase alpha in bracketing
-     plus_minus: bool = False,
- ) -> tuple[float,float] | tuple[None,None]:
-     a_prev = 0.0
-
-     if g_0 == 0: return None,None
-     if g_0 > 0:
-         # if direction is not a descent direction, perform line search in opposite direction
-         if plus_minus:
-             def inverted_objective(alpha):
-                 l, g = f(-alpha)
-                 return l, -g
-             a, v = strong_wolfe(
-                 inverted_objective,
-                 init=init,
-                 f_0=f_0,
-                 g_0=-g_0,
-                 c1=c1,
-                 c2=c2,
-                 maxiter=maxiter,
-                 # a_max=a_max,
-                 expand=expand,
-                 plus_minus=False,
-             )
-             if a is not None and v is not None: return -a, v
-         return None, None
-
-     f_prev = f_0
-     g_prev = g_0
-     a_cur = init
-
-     # bracket
-     for i in range(maxiter):
-
-         f_cur, g_cur = f(a_cur)
-
-         # check armijo
-         armijo_violated = f_cur > f_0 + c1 * a_cur * g_0
-         func_increased = f_cur >= f_prev and i > 0
-
-         if armijo_violated or func_increased:
-             return _zoom(f,
-                 a_prev, a_cur,
-                 f_prev, g_prev,
-                 f_cur, g_cur,
-                 f_0, g_0,
-                 c1, c2,
-                 maxzoom=maxzoom,
-             )
-
-
-         # check strong wolfe
-         if abs(g_cur) <= c2 * abs(g_0):
-             return a_cur, f_cur
-
-         # minimum is bracketed
-         if g_cur >= 0:
-             return _zoom(f,
-                 #alpha_curr, alpha_prev,
-                 a_prev, a_cur,
-                 #phi_curr, phi_prime_curr,
-                 f_prev, g_prev,
-                 f_cur, g_cur,
-                 f_0, g_0,
-                 c1, c2,
-                 maxzoom=maxzoom,)
-
-         # otherwise continue bracketing
-         a_next = a_cur * expand
-
-         # update previous point and continue loop with increased step size
-         a_prev = a_cur
-         f_prev = f_cur
-         g_prev = g_cur
-         a_cur = a_next
-
-
-     # max iters reached
-     return None, None
-
- def _notfinite(x):
-     if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
-     return not math.isfinite(x)
-
- class StrongWolfe(LineSearch):
-     """Cubic interpolation line search satisfying Strong Wolfe condition.
+             # from pytorch
+             min_step = a_i + 0.01 * (a_i - a_prev)
+             max_step = a_i * 10
+             a_i_next = self.interpolate(a_prev, f_prev, g_prev, a_i, f_i, g_i, bounds=(min_step, min(max_step, self.a_max)))
+             # a_i_next = self.interpolate(a_prev, f_prev, g_prev, a_i, f_i, g_i, bounds=(0, self.a_max))
+
+             a_prev, f_prev, g_prev = a_i, f_i, g_i
+             a_i = a_i_next
+
+         if self.num_evals < self.maxeval:
+             assert f_i is not None and g_i is not None
+             return self.zoom(0, self.f_0, self.g_0, a_i, f_i, g_i)
+
+         return None, None, None
+
+
+ class StrongWolfe(LineSearchBase):
+     """Interpolation line search satisfying the Strong Wolfe condition.

      Args:
-         init (float, optional): Initial step size. Defaults to 1.0.
-         c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
-         c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
-         maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
-         maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
-         expand (float, optional): Expansion factor (multiplier to step size when weak condition not satisfied). Defaults to 2.0.
+         c1 (float, optional): sufficient descent condition. Defaults to 1e-4.
+         c2 (float, optional): strong curvature condition. For CG set to 0.1. Defaults to 0.9.
+         a_init (str, optional):
+             strategy for initializing the initial step size guess.
+             - "fixed" - uses a fixed value specified in the ``init_value`` argument.
+             - "first-order" - assumes the first-order change in the function at the current iterate will be the same as that obtained at the previous step.
+             - "quadratic" - interpolates a quadratic to f(x_{-1}) and f_x.
+             - "quadratic-clip" - same as "quadratic", but uses min(1, 1.01*alpha) as described in Numerical Optimization.
+             - "previous" - uses the final step size found on the previous iteration.
+
+             For 2nd order methods it is usually best to leave at "fixed".
+             For methods that do not produce well scaled search directions, e.g. conjugate gradient,
+             "first-order" or "quadratic-clip" are recommended. Defaults to 'fixed'.
+         a_max (float, optional): upper bound for the proposed step sizes. Defaults to 1e12.
+         init_value (float, optional):
+             initial step size. Used when ``a_init="fixed"``, and as a fallback value with the other strategies. Defaults to 1.
+         maxiter (int, optional): maximum number of line search iterations. Defaults to 25.
+         maxzoom (int, optional): maximum number of zoom iterations. Defaults to 10.
+         maxeval (int | None, optional): maximum number of function evaluations. Defaults to None.
+         tol_change (float, optional): tolerance, terminates on small brackets. Defaults to 1e-9.
+         interpolation (str, optional):
+             what type of interpolation to use.
+             - "bisection" - uses the middle point. This is robust, especially if the objective function is non-smooth, however it may need more function evaluations.
+             - "quadratic" - minimizes a quadratic model, generally outperformed by "cubic".
+             - "cubic" - minimizes a cubic model - this is the most widely used interpolation strategy.
+             - "polynomial" - fits a polynomial to all points obtained during the line search.
+             - "polynomial2" - alternative polynomial fit, where if a point is outside of bounds, a lower degree polynomial is tried.
+               This may have faster convergence than "cubic" and "polynomial".
+
+             Defaults to 'cubic'.
          adaptive (bool, optional):
-             when enabled, if line search failed, initial step size is reduced.
-             Otherwise it is reset to initial value. Defaults to True.
+             if True, the initial step size is halved whenever the line search fails to find a good step.
+             When a good step is found, the initial step size is reset to the original value. Defaults to True.
+         fallback (bool, optional):
+             if True, when no point satisfies the strong Wolfe criteria,
+             returns a point with a value lower than the initial value even though it doesn't satisfy the criteria. Defaults to False.
         plus_minus (bool, optional):
-             If enabled and the direction is not a descent direction, performs line search in the opposite direction. Defaults to False.
+             if True, enables the plus-minus variant, where if curvature is negative, the line search is performed
+             in the opposite direction. Defaults to False.
+
+
+     ## Examples:
+
+     Conjugate gradient method with strong Wolfe line search. Nocedal & Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well scaled directions, the initial alpha can be determined from function values by ``a_init="first-order"``.
+
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.PolakRibiere(),
+         tz.m.StrongWolfe(c2=0.1, a_init="first-order")
+     )
+     ```
+
+     LBFGS with strong Wolfe line search:
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LBFGS(),
+         tz.m.StrongWolfe()
+     )
+     ```
+
      """
      def __init__(
          self,
-         init: float = 1.0,
          c1: float = 1e-4,
          c2: float = 0.9,
+         a_init: Literal['first-order', 'quadratic', 'quadratic-clip', 'previous', 'fixed'] = 'fixed',
+         a_max: float = 1e12,
+         init_value: float = 1,
          maxiter: int = 25,
          maxzoom: int = 10,
-         # a_max: float = 1e10,
-         expand: float = 2.0,
+         maxeval: int | None = None,
+         tol_change: float = 1e-9,
+         interpolation: Literal["quadratic", "cubic", "bisection", "polynomial", 'polynomial2'] = 'cubic',
          adaptive = True,
+         fallback: bool = False,
          plus_minus = False,
      ):
-         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
-                       expand=expand, adaptive=adaptive, plus_minus=plus_minus)
+         defaults=dict(init_value=init_value,init=a_init,a_max=a_max,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom, fallback=fallback,
+                       maxeval=maxeval, adaptive=adaptive, interpolation=interpolation, plus_minus=plus_minus, tol_change=tol_change)
          super().__init__(defaults=defaults)

          self.global_state['initial_scale'] = 1.0
-         self.global_state['beta_scale'] = 1.0

      @torch.no_grad
      def search(self, update, var):
+         self._g_prev = self._f_prev = None
          objective = self.make_objective_with_derivative(var=var)

-         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
-             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
-             'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
+         init_value, init, c1, c2, a_max, maxiter, maxzoom, maxeval, interpolation, adaptive, plus_minus, fallback, tol_change = itemgetter(
+             'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
+             'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)
+
+         dir = as_tensorlist(var.get_update())
+         grad_list = var.get_grad()
+
+         g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
+         f_0 = var.get_loss(False)
+         dir_norm = dir.global_vector_norm()
+
+         inverted = False
+         if plus_minus and g_0 > 0:
+             original_objective = objective
+             def inverted_objective(a):
+                 l, g_a = original_objective(-a)
+                 return l, -g_a
+             objective = inverted_objective
+             inverted = True
+
+         # --------------------- determine initial step size guess -------------------- #
+         init = init.lower().strip()
+
+         a_init = init_value
+         if init == 'fixed':
+             pass # use init_value
+
+         elif init == 'previous':
+             if 'a_prev' in self.global_state:
+                 a_init = self.global_state['a_prev']
+
+         elif init == 'first-order':
+             if 'g_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
+                 a_prev = self.global_state['a_prev']
+                 g_prev = self.global_state['g_prev']
+                 if g_prev < 0:
+                     a_init = a_prev * g_prev / g_0
+
+         elif init in ('quadratic', 'quadratic-clip'):
+             if 'f_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
+                 f_prev = self.global_state['f_prev']
+                 if f_0 < f_prev:
+                     a_init = 2 * (f_0 - f_prev) / g_0
+                     if init == 'quadratic-clip': a_init = min(1, 1.01*a_init)
+         else:
+             raise ValueError(init)
+
+         if adaptive:
+             a_init *= self.global_state.get('initial_scale', 1)

-         f_0, g_0 = objective(0)

-         step_size,f_a = strong_wolfe(
-             objective,
-             f_0=f_0, g_0=g_0,
-             init=init * self.global_state.setdefault("initial_scale", 1),
+         strong_wolfe = _StrongWolfe(
+             f=objective,
+             f_0=f_0,
+             g_0=g_0,
+             d_norm=dir_norm,
+             a_init=a_init,
+             a_max=a_max,
              c1=c1,
              c2=c2,
              maxiter=maxiter,
              maxzoom=maxzoom,
-             expand=expand,
-             plus_minus=plus_minus,
+             maxeval=maxeval,
+             tol_change=tol_change,
+             interpolation=interpolation,
          )

-         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
-         if step_size is not None and step_size != 0 and not _notfinite(step_size):
-             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
-             return step_size
+         a, f_a, g_a = strong_wolfe.search()
+         if inverted and a is not None: a = -a
+         if f_a is not None and (f_a > f_0 or not math.isfinite(f_a)): a = None
+
+         if fallback:
+             if a is None or a==0 or not math.isfinite(a):
+                 lowest = min(strong_wolfe.history.items(), key=lambda x: x[1][0])
+                 if lowest[1][0] < f_0:
+                     a = lowest[0]
+                     f_a, g_a = lowest[1]
+                     if inverted: a = -a
+
+         if a is not None and a != 0 and math.isfinite(a):
+             #self.global_state['initial_scale'] = min(1.0, self.global_state.get('initial_scale', 1) * math.sqrt(2))
+             self.global_state['initial_scale'] = 1
+             self.global_state['a_prev'] = a
+             self.global_state['f_prev'] = f_0
+             self.global_state['g_prev'] = g_0
+             return a
+
+         # fail
+         if adaptive:
+             self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
+             finfo = torch.finfo(dir[0].dtype)
+             if self.global_state['initial_scale'] < finfo.tiny * 2:
+                 self.global_state['initial_scale'] = finfo.max / 2

-         # fallback to backtracking on fail
-         if adaptive: self.global_state['initial_scale'] *= 0.5
          return 0
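
Note: the `a_init` warm-start strategies added above ("previous", "first-order", "quadratic", "quadratic-clip") follow the standard step-size initialization rules from Nocedal & Wright, *Numerical Optimization* (section 3.5), which the docstring cites. A minimal standalone sketch of the arithmetic, using illustrative names that are not part of the torchzero API:

```python
def initial_step(strategy: str, a_prev: float, f_prev: float, g_prev: float,
                 f_0: float, g_0: float, fallback: float = 1.0) -> float:
    """Warm-start guess for the first trial step of a line search.

    g_prev and g_0 are directional derivatives (g.T @ d) at the previous and
    current iterates; both are negative along descent directions. f_prev and
    f_0 are the corresponding loss values.
    """
    if strategy == "previous":                    # reuse the last accepted step
        return a_prev
    if strategy == "first-order" and g_prev < 0:  # match the previous first-order decrease
        return a_prev * g_prev / g_0
    if strategy in ("quadratic", "quadratic-clip") and f_0 < f_prev:
        a = 2 * (f_0 - f_prev) / g_0              # minimizer of the quadratic fit to f_prev, f_0, g_0
        return min(1.0, 1.01 * a) if strategy == "quadratic-clip" else a
    return fallback                               # "fixed", or no usable history yet
```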
torchzero/modules/misc/__init__.py (new file)
@@ -0,0 +1,35 @@
+ from .debug import PrintLoss, PrintParams, PrintShape, PrintUpdate
+ from .escape import EscapeAnnealing
+ from .gradient_accumulation import GradientAccumulation
+ from .homotopy import (
+     ExpHomotopy,
+     LambdaHomotopy,
+     LogHomotopy,
+     SqrtHomotopy,
+     SquareHomotopy,
+ )
+ from .misc import (
+     DivByLoss,
+     FillLoss,
+     GradSign,
+     GraftGradToUpdate,
+     GraftToGrad,
+     GraftToParams,
+     HpuEstimate,
+     LastAbsoluteRatio,
+     LastDifference,
+     LastGradDifference,
+     LastProduct,
+     LastRatio,
+     MulByLoss,
+     NoiseSign,
+     Previous,
+     RandomHvp,
+     Relative,
+     UpdateSign,
+     SaveBest,
+ )
+ from .multistep import Multistep, NegateOnLossIncrease, Online, Sequential
+ from .regularization import Dropout, PerturbWeights, WeightDropout
+ from .split import Split
+ from .switch import Alternate, Switch
torchzero/modules/misc/debug.py (new file)
@@ -0,0 +1,48 @@
+ from collections import deque
+
+ import torch
+
+ from ...core import Module
+ from ...utils.tensorlist import Distributions
+
+ class PrintUpdate(Module):
+     """Prints the current update."""
+     def __init__(self, text = 'update = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, var):
+         self.defaults["print_fn"](f'{self.defaults["text"]}{var.update}')
+         return var
+
+ class PrintShape(Module):
+     """Prints shapes of the update."""
+     def __init__(self, text = 'shapes = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, var):
+         shapes = [u.shape for u in var.update] if var.update is not None else None
+         self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
+         return var
+
+ class PrintParams(Module):
+     """Prints the current parameters."""
+     def __init__(self, text = 'params = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, var):
+         self.defaults["print_fn"](f'{self.defaults["text"]}{var.params}')
+         return var
+
+
+ class PrintLoss(Module):
+     """Prints var.get_loss()."""
+     def __init__(self, text = 'loss = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, var):
+         self.defaults["print_fn"](f'{self.defaults["text"]}{var.get_loss(False)}')
+         return var
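
These debug modules are ordinary chain members: each one prints something and passes `var` through unchanged. A hypothetical usage sketch (it assumes the new misc modules are re-exported under the `tz.m` namespace like the modules in the docstring examples above; the surrounding pipeline is illustrative, not prescribed by the diff):

```python
import torchzero as tz

# assumed re-export: PrintLoss from torchzero/modules/misc/debug.py under tz.m
opt = tz.Modular(
    model.parameters(),
    tz.m.PrintLoss(text="loss = "),  # logs the closure loss every step
    tz.m.LBFGS(),
    tz.m.StrongWolfe(),
)
```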