torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,198 +1,257 @@
  from collections import deque
- from operator import itemgetter
+ from collections.abc import Sequence
+ from typing import overload

  import torch

- from ...core import Chainable, Module, Transform, Var, apply_transform
- from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
- from ..functional import safe_scaling_
-
-
- def _adaptive_damping(
-     s: TensorList,
-     y: TensorList,
-     sy: torch.Tensor,
-     init_damping = 0.99,
-     eigval_bounds = (0.01, 1.5)
- ):
-     # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
-     sigma_l, sigma_h = eigval_bounds
-     u = sy / s.dot(s)
-     if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
-     elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
-     else: tau = init_damping
-     y = tau * y + (1-tau) * s
-     sy = s.dot(y)
-
-     return s, y, sy
-
- def lbfgs(
-     tensors_: TensorList,
-     s_history: deque[TensorList],
-     y_history: deque[TensorList],
-     sy_history: deque[torch.Tensor],
-     y: TensorList | None,
-     sy: torch.Tensor | None,
-     z_beta: float | None,
-     z_ema: TensorList | None,
-     step: int,
+ from ...core import Chainable, Transform
+ from ...utils import TensorList, as_tensorlist, unpack_states
+ from ...utils.linalg.linear_operator import LinearOperator
+ from ..functional import initial_step_size
+ from .damping import DampingStrategyType, apply_damping
+
+
+ @torch.no_grad
+ def _make_M(S: torch.Tensor, Y: torch.Tensor, B_0: torch.Tensor):
+     m, n = S.size()
+
+     M = torch.zeros((2 * m, 2 * m), device=S.device, dtype=S.dtype)
+
+     # top-left is B S^T S
+     M[:m, :m] = B_0 * S @ S.mT
+
+     # anti-diagonal is L^T and L
+     L = (S @ Y.mT).tril_(-1)
+
+     M[m:, :m] = L.mT
+     M[:m, m:] = L
+
+     # bottom-right
+     D_diag = (S * Y).sum(1).neg()
+     M[m:, m:] = D_diag.diag_embed()
+
+     return M
+
+
+ @torch.no_grad
+ def lbfgs_Bx(x: torch.Tensor, S: torch.Tensor, Y: torch.Tensor, sy_history, M=None):
+     """L-BFGS hessian-vector product based on compact representation,
+     returns (Bx, M), where M is an internal matrix that depends on S and Y so it can be reused."""
+     m = len(S)
+     if m == 0: return x.clone()
+
+     # initial scaling
+     y = Y[-1]
+     sy = sy_history[-1]
+     yy = y.dot(y)
+     B_0 = yy / sy
+     Bx = x * B_0
+
+     Psi = torch.zeros(2 * m, device=x.device, dtype=x.dtype)
+     Psi[:m] = B_0 * S @ x
+     Psi[m:] = Y @ x
+
+     if M is None: M = _make_M(S, Y, B_0)
+
+     # solve Mu = p
+     u, info = torch.linalg.solve_ex(M, Psi) # pylint:disable=not-callable
+     if info != 0:
+         return Bx
+
+     # Bx
+     u_S = u[:m]
+     u_Y = u[m:]
+     SuS = (S * u_S.unsqueeze(-1)).sum(0)
+     YuY = (Y * u_Y.unsqueeze(-1)).sum(0)
+     return Bx - (B_0 * SuS + YuY), M
+
+
+ @overload
+ def lbfgs_Hx(
+     x: torch.Tensor,
+     s_history: Sequence[torch.Tensor] | torch.Tensor,
+     y_history: Sequence[torch.Tensor] | torch.Tensor,
+     sy_history: Sequence[torch.Tensor] | torch.Tensor,
+ ) -> torch.Tensor: ...
+ @overload
+ def lbfgs_Hx(
+     x: TensorList,
+     s_history: Sequence[TensorList],
+     y_history: Sequence[TensorList],
+     sy_history: Sequence[torch.Tensor] | torch.Tensor,
+ ) -> TensorList: ...
+ def lbfgs_Hx(
+     x,
+     s_history: Sequence | torch.Tensor,
+     y_history: Sequence | torch.Tensor,
+     sy_history: Sequence[torch.Tensor] | torch.Tensor,
  ):
-     if len(s_history) == 0 or y is None or sy is None:
-
-         # initial step size guess modified from pytorch L-BFGS
-         return safe_scaling_(TensorList(tensors_))
+     """L-BFGS inverse-hessian-vector product, works with tensors and TensorLists"""
+     x = x.clone()
+     if len(s_history) == 0: return x

      # 1st loop
      alpha_list = []
-     q = tensors_.clone()
      for s_i, y_i, sy_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-         p_i = 1 / sy_i # this is also denoted as ρ (rho)
-         alpha = p_i * s_i.dot(q)
+         p_i = 1 / sy_i
+         alpha = p_i * s_i.dot(x)
          alpha_list.append(alpha)
-         q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-     # calculate z
-     # s.y/y.y is also this weird y-looking symbol I couldn't find
-     # z is it times q
-     # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-     z = q * (sy / (y.dot(y)))
+         x.sub_(y_i, alpha=alpha)

-     # an attempt into adding momentum, lerping initial z seems stable compared to other variables
-     if z_beta is not None:
-         assert z_ema is not None
-         if step == 1: z_ema.copy_(z)
-         else: z_ema.lerp(z, 1-z_beta)
-         z = z_ema
+     # scaled initial hessian inverse
+     # H_0 = (s.y/y.y) * I, and z = H_0 @ q
+     sy = sy_history[-1]
+     y = y_history[-1]
+     Hx = x * (sy / y.dot(y))

      # 2nd loop
      for s_i, y_i, sy_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
          p_i = 1 / sy_i
-         beta_i = p_i * y_i.dot(z)
-         z.add_(s_i, alpha = alpha_i - beta_i)
+         beta_i = p_i * y_i.dot(Hx)
+         Hx.add_(s_i, alpha = alpha_i - beta_i)

-     return z
+     return Hx

- def _lerp_params_update_(
-     self_: Module,
-     params: list[torch.Tensor],
-     update: list[torch.Tensor],
-     params_beta: list[float | None],
-     grads_beta: list[float | None],
- ):
-     for i, (p, u, p_beta, u_beta) in enumerate(zip(params.copy(), update.copy(), params_beta, grads_beta)):
-         if p_beta is not None or u_beta is not None:
-             state = self_.state[p]

-             if p_beta is not None:
-                 if 'param_ema' not in state: state['param_ema'] = p.clone()
-                 else: state['param_ema'].lerp_(p, 1-p_beta)
-                 params[i] = state['param_ema']
+ class LBFGSLinearOperator(LinearOperator):
+     def __init__(self, s_history: Sequence[torch.Tensor] | torch.Tensor, y_history: Sequence[torch.Tensor] | torch.Tensor, sy_history: Sequence[torch.Tensor] | torch.Tensor):
+         super().__init__()
+         if len(s_history) == 0:
+             self.S = self.Y = self.yy = None
+         else:
+             self.S = s_history
+             self.Y = y_history
+
+         self.sy_history = sy_history
+         self.M = None
+
+     def _get_S(self):
+         if self.S is None: return None
+         if not isinstance(self.S, torch.Tensor):
+             self.S = torch.stack(tuple(self.S))
+         return self.S
+
+     def _get_Y(self):
+         if self.Y is None: return None
+         if not isinstance(self.Y, torch.Tensor):
+             self.Y = torch.stack(tuple(self.Y))
+         return self.Y

-             if u_beta is not None:
-                 if 'grad_ema' not in state: state['grad_ema'] = u.clone()
-                 else: state['grad_ema'].lerp_(u, 1-u_beta)
-                 update[i] = state['grad_ema']
+     def solve(self, b):
+         S = self._get_S(); Y = self._get_Y()
+         if S is None or Y is None: return b.clone()
+         return lbfgs_Hx(b, S, Y, self.sy_history)
+
+     def matvec(self, x):
+         S = self._get_S(); Y = self._get_Y()
+         if S is None or Y is None: return x.clone()
+         Bx, self.M = lbfgs_Bx(x, S, Y, self.sy_history, M=self.M)
+         return Bx
+
+     def size(self):
+         if self.S is None: raise RuntimeError()
+         n = len(self.S[0])
+         return (n, n)

-     return TensorList(params), TensorList(update)

  class LBFGS(Transform):
-     """Limited-memory BFGS algorithm. A line search is recommended, although L-BFGS may be reasonably stable without it.
+     """Limited-memory BFGS algorithm. A line search or trust region is recommended.

      Args:
          history_size (int, optional):
              number of past parameter differences and gradient differences to store. Defaults to 10.
-         damping (bool, optional):
-             whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
-         init_damping (float, optional):
-             initial damping for adaptive dampening. Defaults to 0.9.
-         eigval_bounds (tuple, optional):
-             eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
-         tol (float | None, optional):
-             tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
-         tol_reset (bool, optional):
-             If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+         ptol (float | None, optional):
+             skips updating the history if maximum absolute value of
+             parameter difference is less than this value. Defaults to 1e-10.
+         ptol_restart (bool, optional):
+             If true, whenever parameter difference is less than ``ptol``,
+             L-BFGS state will be reset. Defaults to None.
          gtol (float | None, optional):
-             tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-10.
-         params_beta (float | None, optional):
-             if not None, EMA of parameters is used for preconditioner update. Defaults to None.
-         grads_beta (float | None, optional):
-             if not None, EMA of gradients is used for preconditioner update. Defaults to None.
+             skips updating the history if maximum absolute value of
+             gradient difference is less than this value. Defaults to 1e-10.
+         gtol_restart (bool, optional):
+             If true, whenever gradient difference is less than ``gtol``,
+             L-BFGS state will be reset. Defaults to None.
+         sy_tol (float | None, optional):
+             history will not be updated whenever s⋅y is less than this value (negative s⋅y means negative curvature).
+         scale_first (bool, optional):
+             makes first step, when hessian approximation is not available,
+             small to reduce number of line search iterations. Defaults to True.
          update_freq (int, optional):
-             how often to update L-BFGS history. Defaults to 1.
-         z_beta (float | None, optional):
-             optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
+             how often to update L-BFGS history. Larger values may be better for stochastic optimization. Defaults to 1.
+         damping (DampingStrategyType, optional):
+             damping to use, can be "powell" or "double". Defaults to None.
          inner (Chainable | None, optional):
              optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.

-     Examples:
-         L-BFGS with strong-wolfe line search
-
-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.LBFGS(100),
-                 tz.m.StrongWolfe()
-             )
-
-         Dampened L-BFGS
-
-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.LBFGS(damping=True),
-                 tz.m.StrongWolfe()
-             )
-
-         L-BFGS preconditioning applied to momentum (may be unstable!)
-
-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.LBFGS(inner=tz.m.EMA(0.9)),
-                 tz.m.LR(1e-2)
-             )
+     ## Examples:
+
+     L-BFGS with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LBFGS(100),
+         tz.m.Backtracking()
+     )
+     ```
+
+     L-BFGS with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.TrustCG(tz.m.LBFGS())
+     )
+     ```
      """
      def __init__(
          self,
          history_size=10,
-         damping: bool = False,
-         init_damping=0.9,
-         eigval_bounds=(0.5, 50),
-         tol: float | None = 1e-10,
-         tol_reset: bool = False,
-         gtol: float | None = 1e-10,
-         params_beta: float | None = None,
-         grads_beta: float | None = None,
+         ptol: float | None = 1e-32,
+         ptol_restart: bool = False,
+         gtol: float | None = 1e-32,
+         gtol_restart: bool = False,
+         sy_tol: float = 1e-32,
+         scale_first: bool = True,
          update_freq = 1,
-         z_beta: float | None = None,
+         damping: DampingStrategyType = None,
          inner: Chainable | None = None,
      ):
-         defaults = dict(history_size=history_size, tol=tol, gtol=gtol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
-         super().__init__(defaults, uses_grad=False, inner=inner)
+         defaults = dict(
+             history_size=history_size,
+             scale_first=scale_first,
+             ptol=ptol,
+             gtol=gtol,
+             ptol_restart=ptol_restart,
+             gtol_restart=gtol_restart,
+             sy_tol=sy_tol,
+             damping = damping,
+         )
+         super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)

          self.global_state['s_history'] = deque(maxlen=history_size)
          self.global_state['y_history'] = deque(maxlen=history_size)
          self.global_state['sy_history'] = deque(maxlen=history_size)

-     def reset(self):
+     def _reset_self(self):
          self.state.clear()
          self.global_state['step'] = 0
          self.global_state['s_history'].clear()
          self.global_state['y_history'].clear()
          self.global_state['sy_history'].clear()

+     def reset(self):
+         self._reset_self()
+         for c in self.children.values(): c.reset()
+
      def reset_for_online(self):
          super().reset_for_online()
-         self.clear_state_keys('prev_l_params', 'prev_l_grad')
+         self.clear_state_keys('p_prev', 'g_prev')
          self.global_state.pop('step', None)

      @torch.no_grad
      def update_tensors(self, tensors, params, grads, loss, states, settings):
-         params = as_tensorlist(params)
-         update = as_tensorlist(tensors)
+         p = as_tensorlist(params)
+         g = as_tensorlist(tensors)
          step = self.global_state.get('step', 0)
          self.global_state['step'] = step + 1

@@ -201,86 +260,83 @@ class LBFGS(Transform):
          y_history: deque[TensorList] = self.global_state['y_history']
          sy_history: deque[torch.Tensor] = self.global_state['sy_history']

-         damping,init_damping,eigval_bounds,update_freq = itemgetter('damping','init_damping','eigval_bounds','update_freq')(settings[0])
-         params_beta, grads_beta = unpack_dicts(settings, 'params_beta', 'grads_beta')
+         ptol = self.defaults['ptol']
+         gtol = self.defaults['gtol']
+         ptol_restart = self.defaults['ptol_restart']
+         gtol_restart = self.defaults['gtol_restart']
+         sy_tol = self.defaults['sy_tol']
+         damping = self.defaults['damping']

-         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-         prev_l_params, prev_l_grad = unpack_states(states, tensors, 'prev_l_params', 'prev_l_grad', cls=TensorList)
+         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

          # 1st step - there are no previous params and grads, lbfgs will do normalized SGD step
          if step == 0:
              s = None; y = None; sy = None
          else:
-             s = l_params - prev_l_params
-             y = l_update - prev_l_grad
+             s = p - p_prev
+             y = g - g_prev
+
+             if damping is not None:
+                 s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())
+
              sy = s.dot(y)
+             # damping to be added here

-             if damping:
-                 s, y, sy = _adaptive_damping(s, y, sy, init_damping=init_damping, eigval_bounds=eigval_bounds)
+         below_tol = False
+         # tolerance on parameter difference to avoid exploding after converging
+         if ptol is not None:
+             if s is not None and s.abs().global_max() <= ptol:
+                 if ptol_restart:
+                     self._reset_self()
+                 sy = None
+                 below_tol = True

-         prev_l_params.copy_(l_params)
-         prev_l_grad.copy_(l_update)
+         # tolerance on gradient difference to avoid exploding when there is no curvature
+         if gtol is not None:
+             if y is not None and y.abs().global_max() <= gtol:
+                 if gtol_restart: self._reset_self()
+                 sy = None
+                 below_tol = True

-         # update effective preconditioning state
-         if step % update_freq == 0:
-             if sy is not None and sy > 1e-10:
-                 assert s is not None and y is not None
-                 s_history.append(s)
-                 y_history.append(y)
-                 sy_history.append(sy)
+         # store previous params and grads
+         if not below_tol:
+             p_prev.copy_(p)
+             g_prev.copy_(g)

-         # store for apply
-         self.global_state['s'] = s
-         self.global_state['y'] = y
-         self.global_state['sy'] = sy
+         # update effective preconditioning state
+         if sy is not None and sy > sy_tol:
+             assert s is not None and y is not None and sy is not None

-     def make_Hv(self):
-         ...
+             s_history.append(s)
+             y_history.append(y)
+             sy_history.append(sy)

-     def make_Bv(self):
-         ...
+     def get_H(self, var=...):
+         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
+         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
+         sy_history = self.global_state['sy_history']
+         return LBFGSLinearOperator(s_history, y_history, sy_history)

      @torch.no_grad
      def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         tensors = as_tensorlist(tensors)
-
-         s = self.global_state.pop('s')
-         y = self.global_state.pop('y')
-         sy = self.global_state.pop('sy')
-
-         setting = settings[0]
-         tol = setting['tol']
-         gtol = setting['gtol']
-         tol_reset = setting['tol_reset']
-         z_beta = setting['z_beta']
-
-         # tolerance on parameter difference to avoid exploding after converging
-         if tol is not None:
-             if s is not None and s.abs().global_max() <= tol:
-                 if tol_reset: self.reset()
-                 return safe_scaling_(TensorList(tensors))
+         scale_first = self.defaults['scale_first']

-         # tolerance on gradient difference to avoid exploding when there is no curvature
-         if tol is not None:
-             if y is not None and y.abs().global_max() <= gtol:
-                 return safe_scaling_(TensorList(tensors))
+         tensors = as_tensorlist(tensors)

-         # lerp initial H^-1 @ q guess
-         z_ema = None
-         if z_beta is not None:
-             z_ema = unpack_states(states, tensors, 'z_ema', cls=TensorList)
+         s_history = self.global_state['s_history']
+         y_history = self.global_state['y_history']
+         sy_history = self.global_state['sy_history']

          # precondition
-         dir = lbfgs(
-             tensors_=tensors,
-             s_history=self.global_state['s_history'],
-             y_history=self.global_state['y_history'],
-             sy_history=self.global_state['sy_history'],
-             y=y,
-             sy=sy,
-             z_beta = z_beta,
-             z_ema = z_ema,
-             step=self.global_state.get('step', 1)
+         dir = lbfgs_Hx(
+             x=tensors,
+             s_history=s_history,
+             y_history=y_history,
+             sy_history=sy_history,
          )

+         # scale 1st step
+         if scale_first and self.global_state.get('step', 1) == 1:
+             dir *= initial_step_size(dir, eps=1e-7)
+
          return dir
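
Reference note (not part of the package diff): the rewritten `lbfgs_Hx` above is the standard two-loop recursion for the L-BFGS inverse-Hessian product. Below is a minimal self-contained sketch of that recursion on flat tensors, assuming only PyTorch; the name `two_loop_Hx` and the toy history are illustrative, not torchzero API. The check at the end uses the fact that the recursion satisfies the secant equation H @ y_k == s_k for the newest stored pair.

```python
import torch

def two_loop_Hx(x, s_hist, y_hist, sy_hist):
    """Approximate H @ x, where H is the L-BFGS inverse-Hessian estimate (mirrors lbfgs_Hx)."""
    q = x.clone()
    if len(s_hist) == 0:
        return q
    alphas = []
    # 1st loop: newest pair to oldest
    for s, y, sy in zip(reversed(s_hist), reversed(y_hist), reversed(sy_hist)):
        alpha = s.dot(q) / sy
        alphas.append(alpha)
        q -= alpha * y
    # scaled initial inverse Hessian H_0 = (s.y / y.y) * I, using the newest pair
    q *= sy_hist[-1] / y_hist[-1].dot(y_hist[-1])
    # 2nd loop: oldest pair to newest
    for s, y, sy, alpha in zip(s_hist, y_hist, sy_hist, reversed(alphas)):
        beta = y.dot(q) / sy
        q += (alpha - beta) * s
    return q

# toy history from a positive-definite quadratic, so every s.y > 0 (cf. the sy_tol guard)
torch.manual_seed(0)
A = torch.randn(6, 6); A = A @ A.T + torch.eye(6)
s_hist = [torch.randn(6) for _ in range(4)]
y_hist = [A @ s for s in s_hist]
sy_hist = [s.dot(y) for s, y in zip(s_hist, y_hist)]

# secant equation holds exactly for the newest pair: H @ y_k == s_k
assert torch.allclose(two_loop_Hx(y_hist[-1], s_hist, y_hist, sy_hist), s_hist[-1], rtol=1e-3, atol=1e-5)
```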