torchzero-0.3.9-py3-none-any.whl → torchzero-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
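The bulk of the change is a rewrite of torchzero/modules/quasi_newton/quasi_newton.py, whose diff is reproduced below. As a rough sketch of how the reworked quasi-Newton modules are composed, assuming the tz.Modular / tz.m names used in the docstring examples further down in this diff (the closure-driven training loop is an assumption patterned on torch.optim.LBFGS usage and is not shown in the diff itself):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    x, y = torch.randn(64, 10), torch.randn(64, 1)

    # BFGS preconditioning followed by a strong-Wolfe line search, mirroring
    # the example given in the new BFGS docstring inside this diff.
    opt = tz.Modular(model.parameters(), tz.m.BFGS(), tz.m.StrongWolfe())

    def closure():
        # assumption: line-search modules re-evaluate the loss through a
        # closure, as with torch.optim.LBFGS; the exact step() protocol is
        # not spelled out in this diff.
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    for _ in range(20):
        opt.step(closure)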
torchzero/modules/quasi_newton/quasi_newton.py
@@ -1,11 +1,15 @@
  """Use BFGS or maybe SR1."""
- from typing import Any, Literal
  from abc import ABC, abstractmethod
- from collections.abc import Mapping
+ from collections.abc import Mapping, Callable
+ from typing import Any, Literal
+ import warnings
+
  import torch

- from ...core import Chainable, Module, Preconditioner, TensorwisePreconditioner
- from ...utils import TensorList, set_storage_
+ from ...core import Chainable, Module, TensorwiseTransform, Transform
+ from ...utils import TensorList, set_storage_, unpack_states
+ from ..functional import safe_scaling_
+

  def _safe_dict_update_(d1_:dict, d2:dict):
      inter = set(d1_.keys()).intersection(d2.keys())
@@ -17,14 +21,112 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
      elif state[key].shape != value.shape: state[key] = value
      else: state[key].lerp_(value, 1-beta)

- class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
+ def _safe_clip(x: torch.Tensor):
+     """makes sure scalar tensor x is not smaller than epsilon"""
+     assert x.numel() == 1, x.shape
+     eps = torch.finfo(x.dtype).eps ** 2
+     if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+     return x
+
+ class HessianUpdateStrategy(TensorwiseTransform, ABC):
+     """Base class for quasi-newton methods that store and update a hessian inverse approximation H or a hessian approximation B.
+
+     This is an abstract class; to use it, subclass it and override `update_H` and/or `update_B`.
+
+     Args:
+         defaults (dict | None, optional): defaults. Defaults to None.
+         init_scale (float | Literal["auto"], optional):
+             the initial hessian approximation is set to the identity matrix times this.
+
+             "auto" corresponds to a heuristic from Nocedal & Wright, Numerical Optimization, pp. 142-143.
+
+             Defaults to "auto".
+         tol (float, optional):
+             algorithm-dependent tolerance (usually on the curvature condition). Defaults to 1e-8.
+         ptol (float | None, optional):
+             tolerance on the minimal parameter difference, to avoid instability. Defaults to 1e-10.
+         ptol_reset (bool, optional): whether to reset the hessian approximation when the ptol tolerance is not met. Defaults to False.
+         gtol (float | None, optional):
+             tolerance on the minimal gradient difference, to avoid instability when there is no curvature. Defaults to 1e-10.
+         reset_interval (int | None | Literal["auto"], optional):
+             interval between resets of the hessian approximation.
+
+             "auto" corresponds to the number of decision variables + 1.
+
+             None - no resets.
+
+             Defaults to None.
+         beta (float | None, optional): momentum on H or B. Defaults to None.
+         update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+         scale_first (bool, optional):
+             whether to downscale the first step, taken before a hessian approximation becomes available. Defaults to True.
+         scale_second (bool, optional): whether to downscale the second step. Defaults to False.
+         concat_params (bool, optional):
+             if True, all parameters are treated as a single vector.
+             If False, the update rule is applied to each parameter separately. Defaults to True.
+         inverse (bool, optional):
+             set to True if this method maintains a hessian inverse approximation H and has an `update_H` method.
+             Set to False if it maintains a hessian approximation B and has an `update_B` method.
+             Defaults to True.
+         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+     Example:
+         Implementing the BFGS method, which maintains an estimate of the hessian inverse (H):
+
+         .. code-block:: python
+
+             class BFGS(HessianUpdateStrategy):
+                 def __init__(
+                     self,
+                     init_scale: float | Literal["auto"] = "auto",
+                     tol: float = 1e-8,
+                     ptol: float = 1e-10,
+                     ptol_reset: bool = False,
+                     reset_interval: int | None = None,
+                     beta: float | None = None,
+                     update_freq: int = 1,
+                     scale_first: bool = True,
+                     scale_second: bool = False,
+                     concat_params: bool = True,
+                     inner: Chainable | None = None,
+                 ):
+                     super().__init__(
+                         defaults=None,
+                         init_scale=init_scale,
+                         tol=tol,
+                         ptol=ptol,
+                         ptol_reset=ptol_reset,
+                         reset_interval=reset_interval,
+                         beta=beta,
+                         update_freq=update_freq,
+                         scale_first=scale_first,
+                         scale_second=scale_second,
+                         concat_params=concat_params,
+                         inverse=True,
+                         inner=inner,
+                     )
+
+                 def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+                     tol = settings["tol"]
+                     sy = torch.dot(s, y)
+                     if sy <= tol: return H
+                     num1 = (sy + (y @ H @ y)) * s.outer(s)
+                     term1 = num1.div_(sy**2)
+                     num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+                     term2 = num2.div_(sy)
+                     H += term1.sub_(term2)
+                     return H
+
+     """
      def __init__(
          self,
          defaults: dict | None = None,
          init_scale: float | Literal["auto"] = "auto",
-         tol: float = 1e-10,
-         tol_reset: bool = True,
-         reset_interval: int | None = None,
+         tol: float = 1e-8,
+         ptol: float | None = 1e-10,
+         ptol_reset: bool = False,
+         gtol: float | None = 1e-10,
+         reset_interval: int | None | Literal['auto'] = None,
          beta: float | None = None,
          update_freq: int = 1,
          scale_first: bool = True,
@@ -34,9 +136,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
          inner: Chainable | None = None,
      ):
          if defaults is None: defaults = {}
-         _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, tol_reset=tol_reset, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
+         _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_reset=ptol_reset, gtol=gtol, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
          super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, scale_first=scale_first, inner=inner)

+     def _init_M(self, size:int, device, dtype, is_inverse:bool):
+         return torch.eye(size, device=device, dtype=dtype)
+
      def _get_init_scale(self,s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
          """returns multiplier to H or B"""
          ys = y.dot(s)
@@ -44,47 +149,92 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
          if ys != 0 and yy != 0: return yy/ys
          return 1

-     def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor,inverse:bool, init_scale: Any):
-         set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
+     def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
+         set_storage_(M, self._init_M(s.numel(), device=M.device, dtype=M.dtype, is_inverse=inverse))
          if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
          if init_scale >= 1:
              if inverse: M /= init_scale
              else: M *= init_scale

      def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
-                  p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+                  p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
          """update hessian inverse"""
          raise NotImplementedError

      def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
-                  p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+                  p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
          """update hessian"""
          raise NotImplementedError

+     def reset_for_online(self):
+         super().reset_for_online()
+         self.clear_state_keys('f_prev', 'p_prev', 'g_prev')
+
+     def get_B(self) -> tuple[torch.Tensor, bool]:
+         """returns (B or H, is_inverse)."""
+         state = next(iter(self.state.values()))
+         if "B" in state: return state["B"], False
+         return state["H"], True
+
+     def get_H(self) -> tuple[torch.Tensor, bool]:
+         """returns (H or B, is_inverse)."""
+         state = next(iter(self.state.values()))
+         if "H" in state: return state["H"], False
+         return state["B"], True
+
+     def make_Bv(self) -> Callable[[torch.Tensor], torch.Tensor]:
+         B, is_inverse = self.get_B()
+
+         if is_inverse:
+             H = B
+             warnings.warn(f'{self} maintains H, so Bv will be inefficient!')
+             def Hxv(v): return torch.linalg.solve_ex(H, v)[0] # pylint:disable=not-callable
+             return Hxv
+
+         def Bv(v): return B@v
+         return Bv
+
+     def make_Hv(self) -> Callable[[torch.Tensor], torch.Tensor]:
+         H, is_inverse = self.get_H()
+
+         if is_inverse:
+             B = H
+             warnings.warn(f'{self} maintains B, so Hv will be inefficient!')
+             def Bxv(v): return torch.linalg.solve_ex(B, v)[0] # pylint:disable=not-callable
+             return Bxv
+
+         def Hv(v): return H@v
+         return Hv
+
      @torch.no_grad
-     def update_tensor(self, tensor, param, grad, state, settings):
+     def update_tensor(self, tensor, param, grad, loss, state, setting):
          p = param.view(-1); g = tensor.view(-1)
-         inverse = settings['inverse']
+         inverse = setting['inverse']
          M_key = 'H' if inverse else 'B'
          M = state.get(M_key, None)
-         step = state.get('step', 0)
-         state['step'] = step + 1
-         init_scale = settings['init_scale']
-         tol = settings['tol']
-         tol_reset = settings['tol_reset']
-         reset_interval = settings['reset_interval']
-
-         if M is None:
-             M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
-             if isinstance(init_scale, (int, float)) and init_scale != 1:
-                 if inverse: M /= init_scale
-                 else: M *= init_scale
+         step = state.get('step', 0) + 1
+         state['step'] = step
+         init_scale = setting['init_scale']
+         ptol = setting['ptol']
+         ptol_reset = setting['ptol_reset']
+         gtol = setting['gtol']
+         reset_interval = setting['reset_interval']
+         if reset_interval == 'auto': reset_interval = tensor.numel() + 1
+
+         if M is None or 'f_prev' not in state:
+             if M is None: # won't be true on reset_for_online
+                 M = self._init_M(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
+                 if isinstance(init_scale, (int, float)) and init_scale != 1:
+                     if inverse: M /= init_scale
+                     else: M *= init_scale

              state[M_key] = M
+             state['f_prev'] = loss
              state['p_prev'] = p.clone()
              state['g_prev'] = g.clone()
              return

+         state['f'] = loss
          p_prev = state['p_prev']
          g_prev = state['g_prev']
          s: torch.Tensor = p - p_prev
@@ -92,195 +242,511 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
92
242
  state['p_prev'].copy_(p)
93
243
  state['g_prev'].copy_(g)
94
244
 
95
- if reset_interval is not None and step != 0 and step % reset_interval == 0:
96
- self._reset_M_(M, s, y, inverse, init_scale)
245
+ if reset_interval is not None and step % reset_interval == 0:
246
+ self._reset_M_(M, s, y, inverse, init_scale, state)
247
+ return
248
+
249
+ # tolerance on parameter difference to avoid exploding after converging
250
+ if ptol is not None and s.abs().max() <= ptol:
251
+ if ptol_reset: self._reset_M_(M, s, y, inverse, init_scale, state) # reset history
97
252
  return
98
253
 
99
- # tolerance on gradient difference to avoid exploding after converging
100
- elif y.abs().max() <= tol:
101
- # reset history
102
- if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
254
+ # tolerance on gradient difference to avoid exploding when there is no curvature
255
+ if gtol is not None and y.abs().max() <= gtol:
103
256
  return
104
257
 
105
- if step == 1 and init_scale == 'auto':
258
+ if step == 2 and init_scale == 'auto':
106
259
  if inverse: M /= self._get_init_scale(s,y)
107
260
  else: M *= self._get_init_scale(s,y)
108
261
 
109
- beta = settings['beta']
262
+ beta = setting['beta']
110
263
  if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
111
264
 
112
265
  if inverse:
113
- H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
266
+ H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
114
267
  _maybe_lerp_(state, 'H', H_new, beta)
115
268
 
116
269
  else:
117
- B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
270
+ B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
118
271
  _maybe_lerp_(state, 'B', B_new, beta)
119
272
 
273
+ state['f_prev'] = loss
274
+
275
+ def _post_B(self, B: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
276
+ """modifies B before appling the update rule. Must return (B, g)"""
277
+ return B, g
278
+
279
+ def _post_H(self, H: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
280
+ """modifies H before appling the update rule. Must return (H, g)"""
281
+ return H, g
282
+
120
283
  @torch.no_grad
121
- def apply_tensor(self, tensor, param, grad, state, settings):
284
+ def apply_tensor(self, tensor, param, grad, loss, state, setting):
122
285
  step = state.get('step', 0)
123
286
 
124
- if settings['scale_second'] and step == 2:
125
- scale_factor = 1 / tensor.abs().sum().clip(min=1)
126
- scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
127
- tensor = tensor * scale_factor
287
+ if setting['scale_second'] and step == 2:
288
+ tensor = safe_scaling_(tensor)
128
289
 
129
- inverse = settings['inverse']
290
+ inverse = setting['inverse']
130
291
  if inverse:
131
292
  H = state['H']
132
- return (H @ tensor.view(-1)).view_as(tensor)
293
+ H, g = self._post_H(H, tensor.view(-1), state, setting)
294
+ if H.ndim == 1: return g.mul_(H).view_as(tensor)
295
+ return (H @ g).view_as(tensor)
133
296
 
134
297
  B = state['B']
298
+ H, g = self._post_B(B, tensor.view(-1), state, setting)
299
+
300
+ if B.ndim == 1: return g.div_(B).view_as(tensor)
301
+ x, info = torch.linalg.solve_ex(B, g) # pylint:disable=not-callable
302
+ if info == 0: return x.view_as(tensor)
303
+ return safe_scaling_(tensor)
304
+
305
+ class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
306
+ '''This is :code:`HessianUpdateStrategy` subclass for algorithms with no extra defaults, to skip the lengthy __init__.
307
+ Refer to :code:`HessianUpdateStrategy` documentation.
308
+
309
+ Example:
310
+ Implementing BFGS method that maintains an estimate of the hessian inverse (H):
311
+
312
+ .. code-block:: python
313
+
314
+ class BFGS(_HessianUpdateStrategyDefaults):
315
+ """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
316
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
317
+ tol = settings["tol"]
318
+ sy = torch.dot(s, y)
319
+ if sy <= tol: return H
320
+ num1 = (sy + (y @ H @ y)) * s.outer(s)
321
+ term1 = num1.div_(sy**2)
322
+ num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
323
+ term2 = num2.div_(sy)
324
+ H += term1.sub_(term2)
325
+ return H
326
+
327
+ Make sure to put at least a basic class level docstring to overwrite this.
328
+ '''
329
+ def __init__(
330
+ self,
331
+ init_scale: float | Literal["auto"] = "auto",
332
+ tol: float = 1e-8,
333
+ ptol: float | None = 1e-10,
334
+ ptol_reset: bool = False,
335
+ gtol: float | None = 1e-10,
336
+ reset_interval: int | None = None,
337
+ beta: float | None = None,
338
+ update_freq: int = 1,
339
+ scale_first: bool = True,
340
+ scale_second: bool = False,
341
+ concat_params: bool = True,
342
+ inverse: bool = True,
343
+ inner: Chainable | None = None,
344
+ ):
345
+ super().__init__(
346
+ defaults=None,
347
+ init_scale=init_scale,
348
+ tol=tol,
349
+ ptol=ptol,
350
+ ptol_reset=ptol_reset,
351
+ gtol=gtol,
352
+ reset_interval=reset_interval,
353
+ beta=beta,
354
+ update_freq=update_freq,
355
+ scale_first=scale_first,
356
+ scale_second=scale_second,
357
+ concat_params=concat_params,
358
+ inverse=inverse,
359
+ inner=inner,
360
+ )
135
361
 
136
- return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable
137
-
138
- # to avoid typing all arguments for each method
139
- class HUpdateStrategy(HessianUpdateStrategy):
362
+ class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
140
363
  def __init__(
141
364
  self,
142
365
  init_scale: float | Literal["auto"] = "auto",
143
- tol: float = 1e-10,
144
- tol_reset: bool = True,
366
+ tol: float = 1e-8,
367
+ ptol: float | None = 1e-10,
368
+ ptol_reset: bool = False,
369
+ gtol: float | None = 1e-10,
145
370
  reset_interval: int | None = None,
146
371
  beta: float | None = None,
147
372
  update_freq: int = 1,
148
373
  scale_first: bool = True,
149
374
  scale_second: bool = False,
150
375
  concat_params: bool = True,
376
+ inverse: bool = False,
151
377
  inner: Chainable | None = None,
152
378
  ):
153
379
  super().__init__(
154
380
  defaults=None,
155
381
  init_scale=init_scale,
156
382
  tol=tol,
157
- tol_reset=tol_reset,
383
+ ptol=ptol,
384
+ ptol_reset=ptol_reset,
385
+ gtol=gtol,
158
386
  reset_interval=reset_interval,
159
387
  beta=beta,
160
388
  update_freq=update_freq,
161
389
  scale_first=scale_first,
162
390
  scale_second=scale_second,
163
391
  concat_params=concat_params,
164
- inverse=True,
392
+ inverse=inverse,
165
393
  inner=inner,
166
394
  )
395
+
167
396
  # ----------------------------------- BFGS ----------------------------------- #
397
+ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
398
+ sy = s.dot(y)
399
+ if sy < tol: return B
400
+
401
+ Bs = B@s
402
+ sBs = _safe_clip(s.dot(Bs))
403
+
404
+ term1 = y.outer(y).div_(sy)
405
+ term2 = (Bs.outer(s) @ B.T).div_(sBs)
406
+ B += term1.sub_(term2)
407
+ return B
408
+
168
409
  def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
169
- sy = torch.dot(s, y)
170
- if sy <= tol: return H # don't reset H in this case
171
- num1 = (sy + (y @ H @ y)) * s.outer(s)
172
- term1 = num1.div_(sy**2)
173
- num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
410
+ sy = s.dot(y)
411
+ if sy <= tol: return H
412
+
413
+ sy_sq = _safe_clip(sy**2)
414
+
415
+ Hy = H@y
416
+ scale1 = (sy + y.dot(Hy)) / sy_sq
417
+ term1 = s.outer(s).mul_(scale1)
418
+
419
+ num2 = (Hy.outer(s)).add_(s.outer(y @ H))
174
420
  term2 = num2.div_(sy)
421
+
175
422
  H += term1.sub_(term2)
176
423
  return H
177
424
 
178
- class BFGS(HUpdateStrategy):
179
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
180
- return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])
425
+ class BFGS(_InverseHessianUpdateStrategyDefaults):
426
+ """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
427
+
428
+ .. note::
429
+ a line search such as :code:`tz.m.StrongWolfe()` is recommended, although this can be stable without a line search. Alternatively warmup :code:`tz.m.Warmup` can stabilize quasi-newton methods without line search.
430
+
431
+ .. warning::
432
+ this uses roughly O(N^2) memory.
433
+
434
+ Args:
435
+ init_scale (float | Literal["auto"], optional):
436
+ initial hessian matrix is set to identity times this.
437
+
438
+ "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
439
+
440
+ Defaults to "auto".
441
+ tol (float, optional):
442
+ tolerance on curvature condition. Defaults to 1e-8.
443
+ ptol (float | None, optional):
444
+ skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
445
+ Defaults to 1e-10.
446
+ ptol_reset (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
447
+ reset_interval (int | None | Literal["auto"], optional):
448
+ interval between resetting the hessian approximation.
449
+
450
+ "auto" corresponds to number of decision variables + 1.
451
+
452
+ None - no resets.
453
+
454
+ Defaults to None.
455
+ beta (float | None, optional): momentum on H or B. Defaults to None.
456
+ update_freq (int, optional): frequency of updating H or B. Defaults to 1.
457
+ scale_first (bool, optional):
458
+ whether to downscale first step before hessian approximation becomes available. Defaults to True.
459
+ scale_second (bool, optional): whether to downscale second step. Defaults to False.
460
+ concat_params (bool, optional):
461
+ If true, all parameters are treated as a single vector.
462
+ If False, the update rule is applied to each parameter separately. Defaults to True.
463
+ inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
464
+
465
+ Examples:
466
+ BFGS with strong-wolfe line search:
467
+
468
+ .. code-block:: python
469
+
470
+ opt = tz.Modular(
471
+ model.parameters(),
472
+ tz.m.BFGS(),
473
+ tz.m.StrongWolfe()
474
+ )
475
+
476
+ BFGS preconditioning applied to momentum:
477
+
478
+ .. code-block:: python
479
+
480
+ opt = tz.Modular(
481
+ model.parameters(),
482
+ tz.m.BFGS(inner=tz.m.EMA(0.9)),
483
+ tz.m.LR(1e-2)
484
+ )
485
+ """
486
+
487
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
488
+ return bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
489
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
490
+ return bfgs_B_(B=B, s=s, y=y, tol=setting['tol'])
181
491
 
182
492
  # ------------------------------------ SR1 ----------------------------------- #
183
- def sr1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
493
+ def sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
184
494
  z = s - H@y
185
- denom = torch.dot(z, y)
495
+ denom = z.dot(y)
186
496
 
187
497
  z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
188
498
  y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
189
499
 
190
- if y_norm*z_norm < tol: return H
500
+ # if y_norm*z_norm < tol: return H
191
501
 
192
502
  # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
193
503
  if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
194
- H += torch.outer(z, z).div_(denom)
504
+ H += z.outer(z).div_(_safe_clip(denom))
195
505
  return H
196
506
 
197
- class SR1(HUpdateStrategy):
198
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
199
- return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])
507
+ class SR1(_InverseHessianUpdateStrategyDefaults):
508
+ """Symmetric Rank 1 Quasi-Newton method.
509
+
510
+ .. note::
511
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
512
+
513
+ .. note::
514
+ approximate Hessians generated by the SR1 method show faster progress towards the true Hessian than other methods, but it is more unstable. SR1 is best used within a trust region module.
515
+
516
+ .. note::
517
+ SR1 doesn't enforce the hessian estimate to be positive definite, therefore it can generate directions that are not descent directions.
518
+
519
+ .. warning::
520
+ this uses roughly O(N^2) memory.
521
+
522
+ Args:
523
+ init_scale (float | Literal["auto"], optional):
524
+ initial hessian matrix is set to identity times this.
525
+
526
+ "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
527
+
528
+ Defaults to "auto".
529
+ tol (float, optional):
530
+ tolerance for denominator in SR1 update rule as in Nocedal, Wright. “Numerical optimization” 2nd p.146. Defaults to 1e-8.
531
+ ptol (float | None, optional):
532
+ skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
533
+ Defaults to 1e-10.
534
+ ptol_reset (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
535
+ reset_interval (int | None | Literal["auto"], optional):
536
+ interval between resetting the hessian approximation.
537
+
538
+ "auto" corresponds to number of decision variables + 1.
539
+
540
+ None - no resets.
541
+
542
+ Defaults to None.
543
+ beta (float | None, optional): momentum on H or B. Defaults to None.
544
+ update_freq (int, optional): frequency of updating H or B. Defaults to 1.
545
+ scale_first (bool, optional):
546
+ whether to downscale first step before hessian approximation becomes available. Defaults to True.
547
+ scale_second (bool, optional): whether to downscale second step. Defaults to False.
548
+ concat_params (bool, optional):
549
+ If true, all parameters are treated as a single vector.
550
+ If False, the update rule is applied to each parameter separately. Defaults to True.
551
+ inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
552
+
553
+ Examples:
554
+ SR1 with strong-wolfe line search
555
+
556
+ .. code-block:: python
557
+
558
+ opt = tz.Modular(
559
+ model.parameters(),
560
+ tz.m.SR1(),
561
+ tz.m.StrongWolfe()
562
+ )
563
+
564
+ BFGS preconditioning applied to momentum
565
+
566
+ .. code-block:: python
567
+
568
+ opt = tz.Modular(
569
+ model.parameters(),
570
+ tz.m.SR1(inner=tz.m.EMA(0.9)),
571
+ tz.m.LR(1e-2)
572
+ )
573
+ """
574
+
575
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
576
+ return sr1_(H=H, s=s, y=y, tol=setting['tol'])
577
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
578
+ return sr1_(H=B, s=y, y=s, tol=setting['tol'])
579
+
200
580
 
201
- # BFGS has defaults - init_scale = "auto" and scale_second = False
202
- # SR1 has defaults - init_scale = 1 and scale_second = True
203
- # basically some methods work better with first and some with second.
204
- # I inherit from BFGS or SR1 to avoid writing all those arguments again
205
581
  # ------------------------------------ DFP ----------------------------------- #
206
582
  def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
207
- sy = torch.dot(s, y)
583
+ sy = s.dot(y)
208
584
  if sy.abs() <= tol: return H
209
- term1 = torch.outer(s, s).div_(sy)
210
- denom = torch.dot(y, H @ y) #
211
- if denom.abs() <= tol: return H
212
- num = H @ torch.outer(y, y) @ H
213
- term2 = num.div_(denom)
585
+ term1 = s.outer(s).div_(sy)
586
+
587
+ yHy = _safe_clip(y.dot(H @ y))
588
+
589
+ num = (H @ y).outer(y) @ H
590
+ term2 = num.div_(yHy)
591
+
214
592
  H += term1.sub_(term2)
215
593
  return H
216
594
 
217
- class DFP(HUpdateStrategy):
218
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
219
- return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])
595
+ def dfp_B(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
596
+ sy = s.dot(y)
597
+ if sy.abs() <= tol: return B
598
+ I = torch.eye(B.size(0), device=B.device, dtype=B.dtype)
599
+ sub = y.outer(s).div_(sy)
600
+ term1 = I - sub
601
+ term2 = I.sub_(sub.T)
602
+ term3 = y.outer(y).div_(sy)
603
+ B = (term1 @ B @ term2).add_(term3)
604
+ return B
605
+
606
+
607
+ class DFP(_InverseHessianUpdateStrategyDefaults):
608
+ """Davidon–Fletcher–Powell Quasi-Newton method.
609
+
610
+ .. note::
611
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
612
+
613
+ .. note::
614
+ BFGS is the recommended QN method and will usually outperform this.
615
+
616
+ .. warning::
617
+ this uses roughly O(N^2) memory.
618
+
619
+ """
620
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
621
+ return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
622
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
623
+ return dfp_B(B=B, s=s, y=y, tol=setting['tol'])
220
624
 
221
625
 
222
626
  # formulas for methods below from Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
223
627
  # H' = H - (Hy - S)c^T / c^T*y
224
628
  # the difference is how `c` is calculated
225
629
 
226
- def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
630
+ def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
227
631
  c = H.T @ s
228
- denom = c.dot(y)
229
- if denom.abs() <= tol: return H
632
+ cy = _safe_clip(c.dot(y))
230
633
  num = (H@y).sub_(s).outer(c)
231
- H -= num/denom
634
+ H -= num/cy
232
635
  return H
636
+ def broyden_good_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
637
+ r = y - B@s
638
+ ss = _safe_clip(s.dot(s))
639
+ B += r.outer(s).div_(ss)
640
+ return B
233
641
 
234
- def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
235
- c = y
236
- denom = c.dot(y)
237
- if denom.abs() <= tol: return H
238
- num = (H@y).sub_(s).outer(c)
239
- H -= num/denom
642
+ def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
643
+ yy = _safe_clip(y.dot(y))
644
+ num = (s - (H @ y)).outer(y)
645
+ H += num/yy
240
646
  return H
647
+ def broyden_bad_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
648
+ r = y - B@s
649
+ ys = _safe_clip(y.dot(s))
650
+ B += r.outer(y).div_(ys)
651
+ return B
241
652
 
242
- def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor, tol: float):
653
+ def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor):
243
654
  c = g_prev
244
- denom = c.dot(y)
245
- if denom.abs() <= tol: return H
655
+ cy = _safe_clip(c.dot(y))
246
656
  num = (H@y).sub_(s).outer(c)
247
- H -= num/denom
657
+ H -= num/cy
248
658
  return H
249
659
 
250
- def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
251
- c = torch.linalg.multi_dot([H,H,y]) # pylint:disable=not-callable
252
- denom = c.dot(y)
253
- if denom.abs() <= tol: return H
254
- num = (H@y).sub_(s).outer(c)
255
- H -= num/denom
660
+ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
661
+ Hy = H @ y
662
+ c = H @ Hy # pylint:disable=not-callable
663
+ cy = _safe_clip(c.dot(y))
664
+ num = Hy.sub_(s).outer(c)
665
+ H -= num/cy
256
666
  return H
257
667
 
258
- class BroydenGood(HUpdateStrategy):
259
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
260
- return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])
668
+ class BroydenGood(_InverseHessianUpdateStrategyDefaults):
669
+ """Broyden's "good" Quasi-Newton method.
261
670
 
262
- class BroydenBad(HUpdateStrategy):
263
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
264
- return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
671
+ .. note::
672
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
265
673
 
266
- class Greenstadt1(HUpdateStrategy):
267
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
268
- return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
674
+ .. note::
675
+ BFGS is the recommended QN method and will usually outperform this.
269
676
 
270
- class Greenstadt2(HUpdateStrategy):
271
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
272
- return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
677
+ .. warning::
678
+ this uses roughly O(N^2) memory.
679
+
680
+ Reference:
681
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
682
+ """
683
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
684
+ return broyden_good_H_(H=H, s=s, y=y)
685
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
686
+ return broyden_good_B_(B=B, s=s, y=y)
273
687
 
688
+ class BroydenBad(_InverseHessianUpdateStrategyDefaults):
689
+ """Broyden's "bad" Quasi-Newton method.
274
690
 
275
- def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float):
276
- n = H.shape[0]
691
+ .. note::
692
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
277
693
 
694
+ .. note::
695
+ BFGS is the recommended QN method and will usually outperform this.
696
+
697
+ .. warning::
698
+ this uses roughly O(N^2) memory.
699
+
700
+ Reference:
701
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
702
+ """
703
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
704
+ return broyden_bad_H_(H=H, s=s, y=y)
705
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
706
+ return broyden_bad_B_(B=B, s=s, y=y)
707
+
708
+ class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
709
+ """Greenstadt's first Quasi-Newton method.
710
+
711
+ .. note::
712
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
713
+
714
+ .. note::
715
+ BFGS is the recommended QN method and will usually outperform this.
716
+
717
+ .. warning::
718
+ this uses roughly O(N^2) memory.
719
+
720
+ Reference:
721
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
722
+ """
723
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
724
+ return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev)
725
+
726
+ class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
727
+ """Greenstadt's second Quasi-Newton method.
728
+
729
+ .. note::
730
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
731
+
732
+ .. note::
733
+ BFGS is the recommended QN method and will usually outperform this.
734
+
735
+ .. warning::
736
+ this uses roughly O(N^2) memory.
737
+
738
+ Reference:
739
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
740
+
741
+ """
742
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
743
+ return greenstadt2_H_(H=H, s=s, y=y)
744
+
745
+
746
+ def icum_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
278
747
  j = y.abs().argmax()
279
- u = torch.zeros(n, device=H.device, dtype=H.dtype)
280
- u[j] = 1.0
281
748
 
282
- denom = y[j]
283
- if denom.abs() < tol: return H
749
+ denom = _safe_clip(y[j])
284
750
 
285
751
  Hy = H @ y.unsqueeze(1)
286
752
  num = s.unsqueeze(1) - Hy
@@ -288,51 +754,178 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
288
754
  H[:, j] += num.squeeze() / denom
289
755
  return H
290
756
 
291
- class ColumnUpdatingMethod(HUpdateStrategy):
292
- """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
293
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
294
- return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
757
+ class ICUM(_InverseHessianUpdateStrategyDefaults):
758
+ """
759
+ Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
760
+ due to only updating one column of the inverse hessian approximation per step.
761
+
762
+ .. note::
763
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
764
+
765
+ .. warning::
766
+ this uses roughly O(N^2) memory.
295
767
 
296
- def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
768
+ Reference:
769
+ Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
770
+ """
771
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
772
+ return icum_H_(H=H, s=s, y=y)
773
+
774
+ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
297
775
  s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
298
776
  I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
299
777
  d = (R + I * (s_norm/2)) @ s
300
- denom = d.dot(s)
301
- if denom.abs() <= tol: return H, R
302
- R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(denom)))
778
+ ds = _safe_clip(d.dot(s))
779
+ R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))
303
780
 
304
781
  c = H.T @ d
305
- denom = c.dot(y)
306
- if denom.abs() <= tol: return H, R
782
+ cy = _safe_clip(c.dot(y))
307
783
  num = (H@y).sub_(s).outer(c)
308
- H -= num/denom
784
+ H -= num/cy
309
785
  return H, R
310
786
 
311
- class ThomasOptimalMethod(HUpdateStrategy):
312
- """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
313
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
787
+ class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
788
+ """
789
+ Thomas's "optimal" Quasi-Newton method.
790
+
791
+ .. note::
792
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
793
+
794
+ .. note::
795
+ BFGS is the recommended QN method and will usually outperform this.
796
+
797
+ .. warning::
798
+ this uses roughly O(N^2) memory.
799
+
800
+ Reference:
801
+ Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
802
+ """
803
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
314
804
  if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
315
- H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
805
+ H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
316
806
  return H
317
807
 
808
+ def _reset_M_(self, M, s, y,inverse, init_scale, state):
809
+ super()._reset_M_(M, s, y, inverse, init_scale, state)
810
+ for st in self.state.values():
811
+ st.pop("R", None)
812
+
318
813
  # ------------------------ powell's symmetric broyden ------------------------ #
319
- def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
814
+ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
320
815
  y_Bs = y - B@s
321
- ss = s.dot(s)
322
- if ss.abs() < tol: return B
816
+ ss = _safe_clip(s.dot(s))
323
817
  num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
324
818
  term1 = num1.div_(ss)
325
- term2 = s.outer(s).mul_(y_Bs.dot(s)/(ss**2))
819
+ term2 = s.outer(s).mul_(y_Bs.dot(s)/(_safe_clip(ss**2)))
326
820
  B += term1.sub_(term2)
327
821
  return B
328
822
 
329
- class PSB(HessianUpdateStrategy):
823
+ # I couldn't find formula for H
824
+ class PSB(_HessianUpdateStrategyDefaults):
825
+ """Powell's Symmetric Broyden Quasi-Newton method.
826
+
827
+ .. note::
828
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
829
+
830
+ .. note::
831
+ BFGS is the recommended QN method and will usually outperform this.
832
+
833
+ .. warning::
834
+ this uses roughly O(N^2) memory.
835
+
836
+ Reference:
837
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
838
+ """
839
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
840
+ return psb_B_(B=B, s=s, y=y)
841
+
842
+
843
+ # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
844
+ def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
845
+ Hy = H@y
846
+ yHy = _safe_clip(y.dot(Hy))
847
+ num = (s - Hy).outer(Hy)
848
+ H += num.div_(yHy)
849
+ return H
850
+
851
+ class Pearson(_InverseHessianUpdateStrategyDefaults):
852
+ """
853
+ Pearson's Quasi-Newton method.
854
+
855
+ .. note::
856
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
857
+
858
+ .. note::
859
+ BFGS is the recommended QN method and will usually outperform this.
860
+
861
+ .. warning::
862
+ this uses roughly O(N^2) memory.
863
+
864
+ Reference:
865
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
866
+ """
867
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
868
+ return pearson_H_(H=H, s=s, y=y)
869
+
870
+ def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
871
+ sy = _safe_clip(s.dot(y))
872
+ num = (s - H@y).outer(s)
873
+ H += num.div_(sy)
874
+ return H
875
+
876
+ class McCormick(_InverseHessianUpdateStrategyDefaults):
877
+ """McCormicks's Quasi-Newton method.
878
+
879
+ .. note::
880
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
881
+
882
+ .. note::
883
+ BFGS is the recommended QN method and will usually outperform this.
884
+
885
+ .. warning::
886
+ this uses roughly O(N^2) memory.
887
+
888
+ Reference:
889
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
890
+
891
+ This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method in other sources.
892
+ """
893
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
894
+ return mccormick_H_(H=H, s=s, y=y)
895
+
896
+ def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
897
+ Hy = H @ y
898
+ yHy = _safe_clip(y.dot(Hy))
899
+ H -= Hy.outer(Hy) / yHy
900
+ R += (s - R@y).outer(Hy) / yHy
901
+ return H, R
902
+
903
+ class ProjectedNewtonRaphson(HessianUpdateStrategy):
904
+ """
905
+ Projected Newton Raphson method.
906
+
907
+ .. note::
908
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
909
+
910
+ .. note::
911
+ this is an experimental method.
912
+
913
+ .. warning::
914
+ this uses roughly O(N^2) memory.
915
+
916
+ Reference:
917
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
918
+
919
+ This one is Algorithm 7.
920
+ """
330
921
  def __init__(
331
922
  self,
332
923
  init_scale: float | Literal["auto"] = 'auto',
333
- tol: float = 1e-10,
334
- tol_reset: bool = True,
335
- reset_interval: int | None = None,
924
+ tol: float = 1e-8,
925
+ ptol: float | None = 1e-10,
926
+ ptol_reset: bool = False,
927
+ gtol: float | None = 1e-10,
928
+ reset_interval: int | None | Literal['auto'] = 'auto',
336
929
  beta: float | None = None,
337
930
  update_freq: int = 1,
338
931
  scale_first: bool = True,
@@ -341,34 +934,30 @@ class PSB(HessianUpdateStrategy):
          inner: Chainable | None = None,
      ):
          super().__init__(
-             defaults=None,
              init_scale=init_scale,
              tol=tol,
-             tol_reset=tol_reset,
+             ptol = ptol,
+             ptol_reset=ptol_reset,
+             gtol=gtol,
              reset_interval=reset_interval,
              beta=beta,
              update_freq=update_freq,
              scale_first=scale_first,
              scale_second=scale_second,
              concat_params=concat_params,
-             inverse=False,
+             inverse=True,
              inner=inner,
          )

-     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
-         return psb_B_(B=B, s=s, y=y, tol=settings['tol'])
-
- def pearson2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-     sy = s.dot(y)
-     if sy.abs() <= tol: return H
-     num = (s - H@y).outer(s)
-     H += num.div_(sy)
-     return H
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+         H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y)
+         state["R"] = R
+         return H

- class Pearson2(HUpdateStrategy):
-     """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
-     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-         return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
+     def _reset_M_(self, M, s, y, inverse, init_scale, state):
+         assert inverse
+         M.copy_(state["R"])

  # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
  def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
@@ -380,12 +969,10 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
      # however p.12 says eps = gs / gHy

      Hy = H@y
-     gHy = g.dot(Hy)
-     yHy = y.dot(Hy)
+     gHy = _safe_clip(g.dot(Hy))
+     yHy = _safe_clip(y.dot(Hy))
      sy = s.dot(y)
-     if sy < tol: return H
-     if yHy.abs() < tol: return H
-     if gHy.abs() < tol: return H
+     if sy < tol: return H # the proof is for sy>0. But not clear if it should be skipped

      v_mul = yHy.sqrt()
      v_term1 = s/sy
@@ -400,28 +987,26 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
      e = gs / gHy
      if switch in (1, 3):
          if e/o <= 1:
-             if o.abs() <= tol: return H
-             phi = e/o
+             phi = e/_safe_clip(o)
              theta = 0
          elif o/t >= 1:
-             if t.abs() <= tol: return H
-             phi = o/t
+             phi = o/_safe_clip(t)
              theta = 1
          else:
              phi = 1
-             denom = e*t - o**2
-             if denom.abs() <= tol: return H
+             denom = _safe_clip(e*t - o**2)
              if switch == 1: theta = o * (e - o) / denom
              else: theta = o * (t - o) / denom

      elif switch == 2:
-         if t.abs() <= tol or o.abs() <= tol or e.abs() <= tol: return H
+         t = _safe_clip(t)
+         o = _safe_clip(o)
+         e = _safe_clip(e)
          phi = (e / t) ** 0.5
          theta = 1 / (1 + (t*e / o**2)**0.5)

      elif switch == 4:
-         if t.abs() <= tol: return H
-         phi = e/t
+         phi = e/_safe_clip(t)
          theta = 1/2

      else: raise ValueError(switch)
@@ -440,14 +1025,29 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi


  class SSVM(HessianUpdateStrategy):
-     """This one is from Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
+     """
+     Self-scaling variable metric Quasi-Newton method.
+
+     .. note::
+         a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+     .. note::
+         BFGS is the recommended QN method and will usually outperform this.
+
+     .. warning::
+         this uses roughly O(N^2) memory.
+
+     Reference:
+         Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
      """
      def __init__(
          self,
          switch: tuple[float,float] | Literal[1,2,3,4] = 3,
          init_scale: float | Literal["auto"] = 'auto',
-         tol: float = 1e-10,
-         tol_reset: bool = True,
+         tol: float = 1e-8,
+         ptol: float | None = 1e-10,
+         ptol_reset: bool = False,
+         gtol: float | None = 1e-10,
          reset_interval: int | None = None,
          beta: float | None = None,
          update_freq: int = 1,
@@ -461,7 +1061,262 @@ class SSVM(HessianUpdateStrategy):
             defaults=defaults,
             init_scale=init_scale,
             tol=tol,
-            tol_reset=tol_reset,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return ssvm_H_(H=H, s=s, y=y, g=g, switch=setting['switch'], tol=setting['tol'])
+
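Since the new docstring recommends pairing this update with a strong Wolfe line search, a minimal usage sketch in the style of the GradientCorrection example further down (assuming SSVM is exported under tz.m like the other quasi-Newton modules, and that `model` is an existing torch.nn.Module)::

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.SSVM(),                        # self-scaling variable metric update
        tz.m.StrongWolfe(plus_minus=True),  # line search recommended by the docstring
    )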
+# HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    Hy = H@y
+    ys = y.dot(s)
+    if ys.abs() <= tol: return H # probably? because it is BFGS and DFP-like
+    yHy = y.dot(Hy)
+    denom = _safe_clip(ys + yHy)
+
+    term1 = 1/denom
+    term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
+    term3 = s.outer(y) @ H
+    term4 = Hy.outer(s)
+    term5 = Hy.outer(y) @ H
+
+    inner_term = term2 - term3 - term4 - term5
+    H += inner_term.mul_(term1)
+    return H
+
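Transcribed directly from hoshino_H_ above (code notation, not necessarily the paper's), the update is

.. math::

    H \leftarrow H + \frac{1}{y^\top s + y^\top H y}
    \left[\left(1 + \frac{2\, y^\top H y}{y^\top s}\right) s s^\top
    - s y^\top H - H y s^\top - H y\, y^\top H\right]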
+def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
+    sy = _safe_clip(s.dot(y))
+    return g - (y * (s.dot(g) / sy))
+
+
+class GradientCorrection(Transform):
+    """
+    Estimates the gradient at the minimum along the search direction, assuming the function is quadratic.
+
+    This can be useful as an inner module for second-order methods with an inexact line search.
+
+    Example:
+        L-BFGS with gradient correction
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(inner=tz.m.GradientCorrection()),
+                tz.m.Backtracking()
+            )
+
+    Reference:
+        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+
+    """
+    def __init__(self):
+        super().__init__(None, uses_grad=False)
+
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        if 'p_prev' not in states[0]:
+            p_prev = unpack_states(states, tensors, 'p_prev', init=params)
+            g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
+            return tensors
+
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
+        g_hat = gradient_correction(TensorList(tensors), params-p_prev, tensors-g_prev)
+
+        p_prev.copy_(params)
+        g_prev.copy_(tensors)
+        return g_hat
+
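The correction removes the component of the gradient along the previous step: with g_hat = g - y * (s.dot(g) / s.dot(y)) we get s.dot(g_hat) = 0, and on an exactly quadratic objective g_hat equals the gradient at the minimizer along that line. A small self-contained check on a random quadratic, using plain torch tensors rather than the module API::

    import torch

    torch.manual_seed(0)
    n = 5
    A = torch.randn(n, n)
    A = A @ A.T + n * torch.eye(n)          # SPD Hessian of f(x) = 0.5 * x^T A x
    x_prev, x = torch.randn(n), torch.randn(n)

    g_prev, g = A @ x_prev, A @ x           # exact gradients of the quadratic
    s, y = x - x_prev, g - g_prev           # parameter difference and gradient difference

    g_hat = g - y * (s.dot(g) / s.dot(y))   # same formula as gradient_correction above
    t_star = -s.dot(g) / s.dot(y)           # step to the minimizer along s

    print(s.dot(g_hat))                                            # ~0
    print(torch.allclose(g_hat, A @ (x + t_star * s), atol=1e-5))  # True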
+class Horisho(_InverseHessianUpdateStrategyDefaults):
+    """
+    Hoshino's variable metric Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return hoshino_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+# Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy.abs() < tol: return H # part of algorithm
+    Hy = H @ y
+
+    term1 = (s.outer(y) @ H).div_(sy)
+    term2 = (Hy.outer(s)).div_(sy)
+    term3 = 1 + (y.dot(Hy) / sy)
+    term4 = s.outer(s).div_(sy)
+
+    H -= (term1 + term2 - term4.mul_(term3))
+    return H
+
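Written out, the update implemented by fletcher_vmm_H_ is

.. math::

    H \leftarrow H - \frac{s y^\top H + H y s^\top}{s^\top y}
    + \left(1 + \frac{y^\top H y}{s^\top y}\right)\frac{s s^\top}{s^\top y},

which is the standard BFGS update of the inverse Hessian, consistent with Fletcher (1970) being one of the original BFGS papers.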
+class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
+    """
+    Fletcher's variable metric Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return fletcher_vmm_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+
+# Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
+    sy = s.dot(y)
+    if sy < tol: return H # part of algorithm
+
+    term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy
+
+    if type == 1:
+        pba = (2*sy + 2*(f-f_prev)) / sy
+
+    elif type == 2:
+        pba = (f_prev - f + 1/(2*sy)) / sy
+
+    else:
+        raise RuntimeError(type)
+
+    term3 = 1/pba + y.dot(H@y) / sy
+    term4 = s.outer(s) / sy
+
+    H.sub_(term1)
+    H.add_(term4.mul_(term3))
+    return H
+
+
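new_ssm1 has the same shape as the BFGS inverse update above, except that the constant 1 in the rank-one term is replaced by 1/rho, where rho (the pba variable) is built from the function-value decrease. Transcribed from the code:

.. math::

    H \leftarrow H - \frac{H y s^\top + s y^\top H}{s^\top y}
    + \left(\frac{1}{\rho} + \frac{y^\top H y}{s^\top y}\right)\frac{s s^\top}{s^\top y},
    \qquad
    \rho =
    \begin{cases}
        \dfrac{2\, s^\top y + 2\,(f - f_{prev})}{s^\top y} & \text{type 1} \\[2ex]
        \dfrac{f_{prev} - f + \frac{1}{2\, s^\top y}}{s^\top y} & \text{type 2}
    \end{cases}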
+class NewSSM(HessianUpdateStrategy):
+    """Self-scaling Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+    """
+    def __init__(
+        self,
+        type: Literal[1, 2] = 1,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=dict(type=type),
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        f = state['f']
+        f_prev = state['f_prev']
+        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=setting['type'], tol=setting['tol'])
+
+# ---------------------------- Shor’s r-algorithm ---------------------------- #
+# def shor_r(B:torch.Tensor, y:torch.Tensor, gamma:float):
+#     r = B.T @ y
+#     r /= torch.linalg.vector_norm(r).clip(min=1e-8) # pylint:disable=not-callable
+
+#     I = torch.eye(B.size(1), device=B.device, dtype=B.dtype)
+#     return B @ (I - gamma*r.outer(r))
+
+# this is supposed to be equivalent
+def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
+    p = H@y
+    #(1-y)^2 (ppT)/(pTq)
+    term = p.outer(p).div_(p.dot(y).clip(min=1e-8))
+    H.sub_(term, alpha=1-alpha**2)
+    return H
+
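The commented-out shor_r works on a factor B with H = B @ B.T, while shor_r_ updates H directly; the two coincide when gamma in the B-form equals 1 - alpha in the H-form (this correspondence is inferred from the 1 - alpha**2 factor, not stated in the source). A quick check on random data::

    import torch

    torch.manual_seed(0)
    n, alpha = 6, 0.5
    gamma = 1 - alpha                     # assumed correspondence between the two forms

    B = torch.randn(n, n)
    y = torch.randn(n)
    H = B @ B.T

    # B-form (the commented-out shor_r above)
    r = B.T @ y
    r = r / torch.linalg.vector_norm(r)
    B_new = B @ (torch.eye(n) - gamma * r.outer(r))

    # H-form (shor_r_ above), written out functionally
    p = H @ y
    H_new = H - (1 - alpha**2) * p.outer(p) / p.dot(y)

    print(torch.allclose(B_new @ B_new.T, H_new, atol=1e-5))   # True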
+class ShorR(HessianUpdateStrategy):
+    """Shor’s r-algorithm.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.
+
+    Reference:
+        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720.
+
+        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
+    """
+
+    def __init__(
+        self,
+        alpha=0.5,
+        init_scale: float | Literal["auto"] = 1,
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None | Literal['auto'] = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        # inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha=alpha)
+        super().__init__(
+            defaults=defaults,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
             reset_interval=reset_interval,
             beta=beta,
             update_freq=update_freq,
@@ -472,5 +1327,5 @@ class SSVM(HessianUpdateStrategy):
             inner=inner,
         )

-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-        return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return shor_r_(H=H, y=y, alpha=setting['alpha'])