torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
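Most of this release is a reorganization of the module tree: the former torchzero/modules/optimizers package is split out into a new adaptive package (items 11-32 above), and the lr, ops/misc and several experimental files move or are removed. A minimal sketch of what that means for direct imports, using item 14 ({optimizers → adaptive}/adam.py) as the example; the class name Adam is an assumption based on the file name, not something shown in this diff:

```python
# torchzero 0.3.10 (old path, removed in 0.3.13):
# from torchzero.modules.optimizers.adam import Adam

# torchzero 0.3.13 (new path, per item 14 in the file list above):
from torchzero.modules.adaptive.adam import Adam
```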
@@ -0,0 +1,167 @@
+ from typing import Any, Literal
+ from collections.abc import Callable
+
+ import torch
+ from ...core import Chainable
+ from .quasi_newton import (
+     HessianUpdateStrategy,
+     _HessianUpdateStrategyDefaults,
+     _InverseHessianUpdateStrategyDefaults,
+ )
+
+ from ..functional import safe_clip
+
+
+ def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     sy = s.dot(y)
+     if sy < tol: return H
+
+     sy_sq = safe_clip(sy**2)
+
+     num1 = (sy + (y * H * y)) * s*s
+     term1 = num1.div_(sy_sq)
+     num2 = (H * y * s).add_(s * y * H)
+     term2 = num2.div_(sy)
+     H += term1.sub_(term2)
+     return H
+
+ class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
+     """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+ def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
+     z = s - H*y
+     denom = z.dot(y)
+
+     z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
+     y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
+
+     # if y_norm*z_norm < tol: return H
+
+     # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
+     if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
+     H += (z*z).div_(safe_clip(denom))
+     return H
+ class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
+     """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+
+
+ # Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating. SIAM Journal on Optimization, 1999, Vol. 9, No. 4, pp. 1192-1204.
+ def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+     denom = safe_clip((s**4).sum())
+     num = s.dot(y) - (s*B).dot(s)
+     B += s**2 * (num/denom)
+     return B
+
+ class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
+     """Diagonal quasi-Cauchy method.
+
+     Reference:
+         Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating. SIAM Journal on Optimization, 1999, Vol. 9, No. 4, pp. 1192-1204.
+     """
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_qc_B_(B=B, s=s, y=y)
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+ # Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+ def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+     E_sq = s**2 * B**2
+     denom = safe_clip((s*E_sq).dot(s))
+     num = s.dot(y) - (s*B).dot(s)
+     B += E_sq * (num/denom)
+     return B
+
+ class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
+     """Diagonal weighted quasi-Cauchy method.
+
+     Reference:
+         Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+     """
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_wqc_B_(B=B, s=s, y=y)
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+ def _truncate(B: torch.Tensor, lb, ub):
+     return torch.where((B>lb).logical_and(B<ub), B, 1)
+
+ # Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+ def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+     denom = safe_clip((s**4).sum())
+     num = s.dot(y) + s.dot(s) - (s*B).dot(s)
+     B += s**2 * (num/denom) - 1
+     return B
+
+ class DNRTR(HessianUpdateStrategy):
+     """Diagonal quasi-Newton method.
+
+     Reference:
+         Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+     """
+     def __init__(
+         self,
+         lb: float = 1e-2,
+         ub: float = 1e5,
+         init_scale: float | Literal["auto"] = "auto",
+         tol: float = 1e-32,
+         ptol: float | None = 1e-32,
+         ptol_restart: bool = False,
+         gtol: float | None = 1e-32,
+         restart_interval: int | None | Literal['auto'] = None,
+         beta: float | None = None,
+         update_freq: int = 1,
+         scale_first: bool = False,
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(lb=lb, ub=ub)
+         super().__init__(
+             defaults=defaults,
+             init_scale=init_scale,
+             tol=tol,
+             ptol=ptol,
+             ptol_restart=ptol_restart,
+             gtol=gtol,
+             restart_interval=restart_interval,
+             beta=beta,
+             update_freq=update_freq,
+             scale_first=scale_first,
+             concat_params=concat_params,
+             inverse=False,
+             inner=inner,
+         )
+
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return dnrtr_B_(B=B, s=s, y=y)
+
+     def modify_B(self, B, state, setting):
+         return _truncate(B, setting['lb'], setting['ub'])
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+ # Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+ def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+     denom = safe_clip((s**4).sum())
+     num = s.dot(y)
+     B += s**2 * (num/denom)
+     return B
+
+ class NewDQN(DNRTR):
+     """Diagonal quasi-Newton method.
+
+     Reference:
+         Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+     """
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return new_dqn_B_(B=B, s=s, y=y)
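The hunk above appears to be the new torchzero/modules/quasi_newton/diagonal_quasi_newton.py (item 93 in the file list). A minimal usage sketch, pairing a diagonal update with a line search in the same Modular(...) pattern the LBFGS docstring uses further down; the tz.m.DiagonalBFGS alias is an assumption, not confirmed by this diff:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# diagonal BFGS update followed by a backtracking line search
opt = tz.Modular(
    model.parameters(),
    tz.m.DiagonalBFGS(),  # assumed alias for the DiagonalBFGS class defined above
    tz.m.Backtracking(),
)
```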
@@ -1,162 +1,257 @@
  from collections import deque
- from operator import itemgetter
+ from collections.abc import Sequence
+ from typing import overload
+
  import torch

- from ...core import Transform, Chainable, Module, Var, apply_transform
- from ...utils import TensorList, as_tensorlist, NumberList
+ from ...core import Chainable, Transform
+ from ...utils import TensorList, as_tensorlist, unpack_states
+ from ...utils.linalg.linear_operator import LinearOperator
+ from ..functional import initial_step_size
+ from .damping import DampingStrategyType, apply_damping
+
+
+ @torch.no_grad
+ def _make_M(S:torch.Tensor, Y:torch.Tensor, B_0:torch.Tensor):
+     m,n = S.size()
+
+     M = torch.zeros((2 * m, 2 * m), device=S.device, dtype=S.dtype)
+
+     # top-left is B S^T S
+     M[:m, :m] = B_0 * S @ S.mT
+
+     # anti-diagonal is L^T and L
+     L = (S @ Y.mT).tril_(-1)
+
+     M[m:, :m] = L.mT
+     M[:m, m:] = L
+
+     # bottom-right
+     D_diag = (S * Y).sum(1).neg()
+     M[m:, m:] = D_diag.diag_embed()
+
+     return M
+
+
+ @torch.no_grad
+ def lbfgs_Bx(x: torch.Tensor, S: torch.Tensor, Y: torch.Tensor, sy_history, M=None):
+     """L-BFGS hessian-vector product based on compact representation,
+     returns (Bx, M), where M is an internal matrix that depends on S and Y so it can be reused."""
+     m = len(S)
+     if m == 0: return x.clone(), M
+
+     # initial scaling
+     y = Y[-1]
+     sy = sy_history[-1]
+     yy = y.dot(y)
+     B_0 = yy / sy
+     Bx = x * B_0
+
+     Psi = torch.zeros(2 * m, device=x.device, dtype=x.dtype)
+     Psi[:m] = B_0 * S@x
+     Psi[m:] = Y@x
+
+     if M is None: M = _make_M(S, Y, B_0)
+
+     # solve M u = Psi
+     u, info = torch.linalg.solve_ex(M, Psi) # pylint:disable=not-callable
+     if info != 0:
+         return Bx, M
+
+     # Bx
+     u_S = u[:m]
+     u_Y = u[m:]
+     SuS = (S * u_S.unsqueeze(-1)).sum(0)
+     YuY = (Y * u_Y.unsqueeze(-1)).sum(0)
+     return Bx - (B_0 * SuS + YuY), M
+
+
+ @overload
+ def lbfgs_Hx(
+     x: torch.Tensor,
+     s_history: Sequence[torch.Tensor] | torch.Tensor,
+     y_history: Sequence[torch.Tensor] | torch.Tensor,
+     sy_history: Sequence[torch.Tensor] | torch.Tensor,
+ ) -> torch.Tensor: ...
+ @overload
+ def lbfgs_Hx(
+     x: TensorList,
+     s_history: Sequence[TensorList],
+     y_history: Sequence[TensorList],
+     sy_history: Sequence[torch.Tensor] | torch.Tensor,
+ ) -> TensorList: ...
+ def lbfgs_Hx(
+     x,
+     s_history: Sequence | torch.Tensor,
+     y_history: Sequence | torch.Tensor,
+     sy_history: Sequence[torch.Tensor] | torch.Tensor,
+ ):
+     """L-BFGS inverse-hessian-vector product, works with tensors and TensorLists"""
+     x = x.clone()
+     if len(s_history) == 0: return x
+
+     # 1st loop
+     alpha_list = []
+     for s_i, y_i, sy_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+         p_i = 1 / sy_i
+         alpha = p_i * s_i.dot(x)
+         alpha_list.append(alpha)
+         x.sub_(y_i, alpha=alpha)
+
+     # scaled initial hessian inverse
+     # H_0 = (s.y/y.y) * I, and z = H_0 @ q
+     sy = sy_history[-1]
+     y = y_history[-1]
+     Hx = x * (sy / y.dot(y))
+
+     # 2nd loop
+     for s_i, y_i, sy_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+         p_i = 1 / sy_i
+         beta_i = p_i * y_i.dot(Hx)
+         Hx.add_(s_i, alpha = alpha_i - beta_i)
+
+     return Hx
+
+
+ class LBFGSLinearOperator(LinearOperator):
+     def __init__(self, s_history: Sequence[torch.Tensor] | torch.Tensor, y_history: Sequence[torch.Tensor] | torch.Tensor, sy_history: Sequence[torch.Tensor] | torch.Tensor):
+         super().__init__()
+         if len(s_history) == 0:
+             self.S = self.Y = self.yy = None
+         else:
+             self.S = s_history
+             self.Y = y_history

+         self.sy_history = sy_history
+         self.M = None

- def _adaptive_damping(
-     s_k: TensorList,
-     y_k: TensorList,
-     ys_k: torch.Tensor,
-     init_damping = 0.99,
-     eigval_bounds = (0.01, 1.5)
- ):
-     # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
-     sigma_l, sigma_h = eigval_bounds
-     u = ys_k / s_k.dot(s_k)
-     if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
-     elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
-     else: tau = init_damping
-     y_k = tau * y_k + (1-tau) * s_k
-     ys_k = s_k.dot(y_k)
-
-     return s_k, y_k, ys_k
-
- def lbfgs(
-     tensors_: TensorList,
-     s_history: deque[TensorList],
-     y_history: deque[TensorList],
-     sy_history: deque[torch.Tensor],
-     y_k: TensorList | None,
-     ys_k: torch.Tensor | None,
-     z_beta: float | None,
-     z_ema: TensorList | None,
-     step: int,
- ):
-     if len(s_history) == 0 or y_k is None or ys_k is None:
-
-         # initial step size guess modified from pytorch L-BFGS
-         scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
-         scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-         return tensors_.mul_(scale_factor)
-
-     else:
-         # 1st loop
-         alpha_list = []
-         q = tensors_.clone()
-         for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-             p_i = 1 / ys_i # this is also denoted as ρ (rho)
-             alpha = p_i * s_i.dot(q)
-             alpha_list.append(alpha)
-             q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-         # calculate z
-         # s.y/y.y is also this weird y-looking symbol I couldn't find
-         # z is it times q
-         # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-         z = q * (ys_k / (y_k.dot(y_k)))
-
-         # an attempt into adding momentum, lerping initial z seems stable compared to other variables
-         if z_beta is not None:
-             assert z_ema is not None
-             if step == 0: z_ema.copy_(z)
-             else: z_ema.lerp(z, 1-z_beta)
-             z = z_ema
-
-         # 2nd loop
-         for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
-             p_i = 1 / ys_i
-             beta_i = p_i * y_i.dot(z)
-             z.add_(s_i, alpha = alpha_i - beta_i)
-
-         return z
-
- def _lerp_params_update_(
-     self_: Module,
-     params: list[torch.Tensor],
-     update: list[torch.Tensor],
-     params_beta: list[float | None],
-     grads_beta: list[float | None],
- ):
-     for i, (p, u, p_beta, u_beta) in enumerate(zip(params.copy(), update.copy(), params_beta, grads_beta)):
-         if p_beta is not None or u_beta is not None:
-             state = self_.state[p]
+     def _get_S(self):
+         if self.S is None: return None
+         if not isinstance(self.S, torch.Tensor):
+             self.S = torch.stack(tuple(self.S))
+         return self.S

-         if p_beta is not None:
-             if 'param_ema' not in state: state['param_ema'] = p.clone()
-             else: state['param_ema'].lerp_(p, 1-p_beta)
-             params[i] = state['param_ema']
+     def _get_Y(self):
+         if self.Y is None: return None
+         if not isinstance(self.Y, torch.Tensor):
+             self.Y = torch.stack(tuple(self.Y))
+         return self.Y

-         if u_beta is not None:
-             if 'grad_ema' not in state: state['grad_ema'] = u.clone()
-             else: state['grad_ema'].lerp_(u, 1-u_beta)
-             update[i] = state['grad_ema']
+     def solve(self, b):
+         S = self._get_S(); Y = self._get_Y()
+         if S is None or Y is None: return b.clone()
+         return lbfgs_Hx(b, S, Y, self.sy_history)

-     return TensorList(params), TensorList(update)
+     def matvec(self, x):
+         S = self._get_S(); Y = self._get_Y()
+         if S is None or Y is None: return x.clone()
+         Bx, self.M = lbfgs_Bx(x, S, Y, self.sy_history, M=self.M)
+         return Bx

- class LBFGS(Module):
-     """L-BFGS
+     def size(self):
+         if self.S is None: raise RuntimeError()
+         n = len(self.S[0])
+         return (n, n)
+
+
+ class LBFGS(Transform):
+     """Limited-memory BFGS algorithm. A line search or trust region is recommended.

      Args:
-         history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
-         tol (float | None, optional):
-             tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
-         damping (bool, optional):
-             whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
-         init_damping (float, optional):
-             initial damping for adaptive dampening. Defaults to 0.9.
-         eigval_bounds (tuple, optional):
-             eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
-         params_beta (float | None, optional):
-             if not None, EMA of parameters is used for preconditioner update. Defaults to None.
-         grads_beta (float | None, optional):
-             if not None, EMA of gradients is used for preconditioner update. Defaults to None.
+         history_size (int, optional):
+             number of past parameter differences and gradient differences to store. Defaults to 10.
+         ptol (float | None, optional):
+             skips updating the history if the maximum absolute value of the
+             parameter difference is less than this value. Defaults to 1e-32.
+         ptol_restart (bool, optional):
+             If true, whenever the parameter difference is less than ``ptol``,
+             L-BFGS state will be reset. Defaults to False.
+         gtol (float | None, optional):
+             skips updating the history if the maximum absolute value of the
+             gradient difference is less than this value. Defaults to 1e-32.
+         gtol_restart (bool, optional):
+             If true, whenever the gradient difference is less than ``gtol``,
+             L-BFGS state will be reset. Defaults to False.
+         sy_tol (float | None, optional):
+             history will not be updated whenever s⋅y is less than this value (negative s⋅y means negative curvature). Defaults to 1e-32.
+         scale_first (bool, optional):
+             makes the first step, when the hessian approximation is not yet available,
+             small to reduce the number of line search iterations. Defaults to True.
          update_freq (int, optional):
-             how often to update L-BFGS history. Defaults to 1.
-         z_beta (float | None, optional):
-             optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
-         tol_reset (bool, optional):
-             If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+             how often to update L-BFGS history. Larger values may be better for stochastic optimization. Defaults to 1.
+         damping (DampingStrategyType, optional):
+             damping to use, can be "powell" or "double". Defaults to None.
          inner (Chainable | None, optional):
              optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
+
+     ## Examples:
+
+     L-BFGS with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LBFGS(100),
+         tz.m.Backtracking()
+     )
+     ```
+
+     L-BFGS with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.TrustCG(tz.m.LBFGS())
+     )
+     ```
      """
      def __init__(
          self,
          history_size=10,
-         tol: float | None = 1e-10,
-         damping: bool = False,
-         init_damping=0.9,
-         eigval_bounds=(0.5, 50),
-         params_beta: float | None = None,
-         grads_beta: float | None = None,
+         ptol: float | None = 1e-32,
+         ptol_restart: bool = False,
+         gtol: float | None = 1e-32,
+         gtol_restart: bool = False,
+         sy_tol: float = 1e-32,
+         scale_first:bool=True,
          update_freq = 1,
-         z_beta: float | None = None,
-         tol_reset: bool = False,
+         damping: DampingStrategyType = None,
          inner: Chainable | None = None,
      ):
-         defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
-         super().__init__(defaults)
+         defaults = dict(
+             history_size=history_size,
+             scale_first=scale_first,
+             ptol=ptol,
+             gtol=gtol,
+             ptol_restart=ptol_restart,
+             gtol_restart=gtol_restart,
+             sy_tol=sy_tol,
+             damping = damping,
+         )
+         super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)

          self.global_state['s_history'] = deque(maxlen=history_size)
          self.global_state['y_history'] = deque(maxlen=history_size)
          self.global_state['sy_history'] = deque(maxlen=history_size)

-         if inner is not None:
-             self.set_child('inner', inner)
-
-     def reset(self):
+     def _reset_self(self):
          self.state.clear()
          self.global_state['step'] = 0
          self.global_state['s_history'].clear()
          self.global_state['y_history'].clear()
          self.global_state['sy_history'].clear()

+     def reset(self):
+         self._reset_self()
+         for c in self.children.values(): c.reset()
+
+     def reset_for_online(self):
+         super().reset_for_online()
+         self.clear_state_keys('p_prev', 'g_prev')
+         self.global_state.pop('step', None)
+
      @torch.no_grad
-     def step(self, var):
-         params = as_tensorlist(var.params)
-         update = as_tensorlist(var.get_update())
+     def update_tensors(self, tensors, params, grads, loss, states, settings):
+         p = as_tensorlist(params)
+         g = as_tensorlist(tensors)
          step = self.global_state.get('step', 0)
          self.global_state['step'] = step + 1

@@ -165,65 +260,83 @@ class LBFGS(Module):
          y_history: deque[TensorList] = self.global_state['y_history']
          sy_history: deque[torch.Tensor] = self.global_state['sy_history']

-         tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
-             'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
-         params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')
+         ptol = self.defaults['ptol']
+         gtol = self.defaults['gtol']
+         ptol_restart = self.defaults['ptol_restart']
+         gtol_restart = self.defaults['gtol_restart']
+         sy_tol = self.defaults['sy_tol']
+         damping = self.defaults['damping']

-         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-         prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
+         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

-         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
+         # 1st step - there are no previous params and grads, lbfgs will do normalized SGD step
          if step == 0:
-             s_k = None; y_k = None; ys_k = None
+             s = None; y = None; sy = None
          else:
-             s_k = l_params - prev_l_params
-             y_k = l_update - prev_l_grad
-             ys_k = s_k.dot(y_k)
+             s = p - p_prev
+             y = g - g_prev
+
+             if damping is not None:
+                 s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())
+
+             sy = s.dot(y)
+         # damping to be added here
+
+         below_tol = False
+         # tolerance on parameter difference to avoid exploding after converging
+         if ptol is not None:
+             if s is not None and s.abs().global_max() <= ptol:
+                 if ptol_restart:
+                     self._reset_self()
+                 sy = None
+                 below_tol = True
+
+         # tolerance on gradient difference to avoid exploding when there is no curvature
+         if gtol is not None:
+             if y is not None and y.abs().global_max() <= gtol:
+                 if gtol_restart: self._reset_self()
+                 sy = None
+                 below_tol = True
+
+         # store previous params and grads
+         if not below_tol:
+             p_prev.copy_(p)
+             g_prev.copy_(g)

-             if damping:
-                 s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
+         # update effective preconditioning state
+         if sy is not None and sy > sy_tol:
+             assert s is not None and y is not None and sy is not None

-         prev_l_params.copy_(l_params)
-         prev_l_grad.copy_(l_update)
+             s_history.append(s)
+             y_history.append(y)
+             sy_history.append(sy)

-         # update effective preconditioning state
-         if step % update_freq == 0:
-             if ys_k is not None and ys_k > 1e-10:
-                 assert s_k is not None and y_k is not None
-                 s_history.append(s_k)
-                 y_history.append(y_k)
-                 sy_history.append(ys_k)
-
-         # step with inner module before applying preconditioner
-         if self.children:
-             update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
-
-         # tolerance on gradient difference to avoid exploding after converging
-         if tol is not None:
-             if y_k is not None and y_k.abs().global_max() <= tol:
-                 var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                 if tol_reset: self.reset()
-                 return var
-
-         # lerp initial H^-1 @ q guess
-         z_ema = None
-         if z_beta is not None:
-             z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)
+     def get_H(self, var=...):
+         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
+         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
+         sy_history = self.global_state['sy_history']
+         return LBFGSLinearOperator(s_history, y_history, sy_history)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         scale_first = self.defaults['scale_first']
+
+         tensors = as_tensorlist(tensors)
+
+         s_history = self.global_state['s_history']
+         y_history = self.global_state['y_history']
+         sy_history = self.global_state['sy_history']

          # precondition
-         dir = lbfgs(
-             tensors_=as_tensorlist(update),
+         dir = lbfgs_Hx(
+             x=tensors,
              s_history=s_history,
              y_history=y_history,
              sy_history=sy_history,
-             y_k=y_k,
-             ys_k=ys_k,
-             z_beta = z_beta,
-             z_ema = z_ema,
-             step=step
          )

-         var.update = dir
-
-         return var
+         # scale 1st step
+         if scale_first and self.global_state.get('step', 1) == 1:
+             dir *= initial_step_size(dir, eps=1e-7)

+         return dir
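The hunk above rewrites torchzero/modules/quasi_newton/lbfgs.py (item 94): the two-loop recursion moves into a standalone lbfgs_Hx, a compact-representation Hessian-vector product lbfgs_Bx is added, and both are wrapped in LBFGSLinearOperator so that get_H can expose the approximation to trust-region modules. A standalone sketch of that relationship, assuming the helpers are importable from torchzero.modules.quasi_newton.lbfgs; the history pairs below are synthetic and only illustrate that matvec (B·x) and solve (H·x) act as mutual inverses when built from the same history:

```python
import torch
from torchzero.modules.quasi_newton.lbfgs import LBFGSLinearOperator

torch.manual_seed(0)
n, m = 8, 4

# synthetic curvature pairs with positive s.y, as the sy_tol check in update_tensors requires
s_history = [torch.randn(n, dtype=torch.float64) for _ in range(m)]
y_history = [s + 0.1 * torch.randn(n, dtype=torch.float64) for s in s_history]
sy_history = [s.dot(y) for s, y in zip(s_history, y_history)]

H = LBFGSLinearOperator(torch.stack(s_history), torch.stack(y_history), sy_history)

v = torch.randn(n, dtype=torch.float64)
d = H.solve(v)          # two-loop recursion: approximately H @ v
v_back = H.matvec(d)    # compact representation: approximately B @ (H @ v) = v
print((v - v_back).abs().max())  # small residual if B and H are consistent inverses
```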