torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
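The single file diff reproduced below appears to be torchzero/modules/quasi_newton/quasi_newton.py (item 99), which reworks the quasi-newton Hessian-update modules: the old tol_reset flag is split into ptol/ptol_reset/gtol, tol now defaults to 1e-8, the settings argument of update_H/update_B is renamed to setting, and B-form (Hessian) updates are added alongside the H-form (inverse) updates. A minimal usage sketch, assuming the 0.3.11 constructor signature shown in the diff and the tz.Modular / tz.m style used in its docstrings (the model and the import alias are illustrative, not taken from the diff):

    import torch.nn as nn
    import torchzero as tz  # assumed import alias, matching the "tz." prefix in the docstrings below

    model = nn.Linear(4, 1)

    # BFGS keeps an inverse-Hessian estimate; StrongWolfe provides the line search
    # recommended in the BFGS docstring. Keyword names follow the new 0.3.11 signature.
    opt = tz.Modular(
        model.parameters(),
        tz.m.BFGS(tol=1e-8, ptol=1e-10, gtol=1e-10, ptol_reset=False),
        tz.m.StrongWolfe(),
    )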
@@ -1,12 +1,14 @@
  """Use BFGS or maybe SR1."""
  from abc import ABC, abstractmethod
- from collections.abc import Mapping
+ from collections.abc import Mapping, Callable
  from typing import Any, Literal
+ import warnings

  import torch

  from ...core import Chainable, Module, TensorwiseTransform, Transform
  from ...utils import TensorList, set_storage_, unpack_states
+ from ..functional import safe_scaling_


  def _safe_dict_update_(d1_:dict, d2:dict):
@@ -19,13 +21,111 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
  elif state[key].shape != value.shape: state[key] = value
  else: state[key].lerp_(value, 1-beta)

+ def _safe_clip(x: torch.Tensor):
+ """makes sure scalar tensor x is not smaller than epsilon"""
+ assert x.numel() == 1, x.shape
+ eps = torch.finfo(x.dtype).eps ** 2
+ if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+ return x
+
  class HessianUpdateStrategy(TensorwiseTransform, ABC):
+ """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.
+
+ This is an abstract class, to use it, subclass it and override `update_H` and/or `update_B`.
+
+ Args:
+ defaults (dict | None, optional): defaults. Defaults to None.
+ init_scale (float | Literal["auto"], optional):
+ initial hessian matrix is set to identity times this.
+
+ "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
+
+ Defaults to "auto".
+ tol (float, optional):
+ algorithm-dependent tolerance (usually on curvature condition). Defaults to 1e-8.
+ ptol (float | None, optional):
+ tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
+ ptol_reset (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+ gtol (float | None, optional):
+ tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-10.
+ reset_interval (int | None | Literal["auto"], optional):
+ interval between resetting the hessian approximation.
+
+ "auto" corresponds to number of decision variables + 1.
+
+ None - no resets.
+
+ Defaults to None.
+ beta (float | None, optional): momentum on H or B. Defaults to None.
+ update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+ scale_first (bool, optional):
+ whether to downscale first step before hessian approximation becomes available. Defaults to True.
+ scale_second (bool, optional): whether to downscale second step. Defaults to False.
+ concat_params (bool, optional):
+ If true, all parameters are treated as a single vector.
+ If False, the update rule is applied to each parameter separately. Defaults to True.
+ inverse (bool, optional):
+ set to True if this method uses hessian inverse approximation H and has `update_H` method.
+ set to False if this maintains hessian approximation B and has `update_B method`.
+ Defaults to True.
+ inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+ Example:
+ Implementing BFGS method that maintains an estimate of the hessian inverse (H):
+
+ .. code-block:: python
+
+ class BFGS(HessianUpdateStrategy):
+ def __init__(
+ self,
+ init_scale: float | Literal["auto"] = "auto",
+ tol: float = 1e-8,
+ ptol: float = 1e-10,
+ ptol_reset: bool = False,
+ reset_interval: int | None = None,
+ beta: float | None = None,
+ update_freq: int = 1,
+ scale_first: bool = True,
+ scale_second: bool = False,
+ concat_params: bool = True,
+ inner: Chainable | None = None,
+ ):
+ super().__init__(
+ defaults=None,
+ init_scale=init_scale,
+ tol=tol,
+ ptol=ptol,
+ ptol_reset=ptol_reset,
+ reset_interval=reset_interval,
+ beta=beta,
+ update_freq=update_freq,
+ scale_first=scale_first,
+ scale_second=scale_second,
+ concat_params=concat_params,
+ inverse=True,
+ inner=inner,
+ )
+
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ tol = settings["tol"]
+ sy = torch.dot(s, y)
+ if sy <= tol: return H
+ num1 = (sy + (y @ H @ y)) * s.outer(s)
+ term1 = num1.div_(sy**2)
+ num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+ term2 = num2.div_(sy)
+ H += term1.sub_(term2)
+ return H
+
+ """
  def __init__(
  self,
  defaults: dict | None = None,
  init_scale: float | Literal["auto"] = "auto",
- tol: float = 1e-10,
- tol_reset: bool = True,
+ tol: float = 1e-8,
+ ptol: float | None = 1e-10,
+ ptol_reset: bool = False,
+ gtol: float | None = 1e-10,
  reset_interval: int | None | Literal['auto'] = None,
  beta: float | None = None,
  update_freq: int = 1,
@@ -36,9 +136,12 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  inner: Chainable | None = None,
  ):
  if defaults is None: defaults = {}
- _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, tol_reset=tol_reset, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
+ _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_reset=ptol_reset, gtol=gtol, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
  super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, scale_first=scale_first, inner=inner)

+ def _init_M(self, size:int, device, dtype, is_inverse:bool):
+ return torch.eye(size, device=device, dtype=dtype)
+
  def _get_init_scale(self,s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
  """returns multiplier to H or B"""
  ys = y.dot(s)
@@ -47,41 +150,83 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  return 1

  def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
- set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
+ set_storage_(M, self._init_M(s.numel(), device=M.device, dtype=M.dtype, is_inverse=inverse))
  if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
  if init_scale >= 1:
  if inverse: M /= init_scale
  else: M *= init_scale

  def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
- p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+ p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
  """update hessian inverse"""
  raise NotImplementedError

  def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
- p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+ p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
  """update hessian"""
  raise NotImplementedError

+ def reset_for_online(self):
+ super().reset_for_online()
+ self.clear_state_keys('f_prev', 'p_prev', 'g_prev')
+
+ def get_B(self) -> tuple[torch.Tensor, bool]:
+ """returns (B or H, is_inverse)."""
+ state = next(iter(self.state.values()))
+ if "B" in state: return state["B"], False
+ return state["H"], True
+
+ def get_H(self) -> tuple[torch.Tensor, bool]:
+ """returns (H or B, is_inverse)."""
+ state = next(iter(self.state.values()))
+ if "H" in state: return state["H"], False
+ return state["B"], True
+
+ def make_Bv(self) -> Callable[[torch.Tensor], torch.Tensor]:
+ B, is_inverse = self.get_B()
+
+ if is_inverse:
+ H=B
+ warnings.warn(f'{self} maintains H, so Bv will be inefficient!')
+ def Hxv(v): return torch.linalg.solve_ex(H, v)[0] # pylint:disable=not-callable
+ return Hxv
+
+ def Bv(v): return B@v
+ return Bv
+
+ def make_Hv(self) -> Callable[[torch.Tensor], torch.Tensor]:
+ H, is_inverse = self.get_H()
+
+ if is_inverse:
+ B=H
+ warnings.warn(f'{self} maintains B, so Hv will be inefficient!')
+ def Bxv(v): return torch.linalg.solve_ex(B, v)[0] # pylint:disable=not-callable
+ return Bxv
+
+ def Hv(v): return H@v
+ return Hv
+
  @torch.no_grad
- def update_tensor(self, tensor, param, grad, loss, state, settings):
+ def update_tensor(self, tensor, param, grad, loss, state, setting):
  p = param.view(-1); g = tensor.view(-1)
- inverse = settings['inverse']
+ inverse = setting['inverse']
  M_key = 'H' if inverse else 'B'
  M = state.get(M_key, None)
- step = state.get('step', 0)
- state['step'] = step + 1
- init_scale = settings['init_scale']
- tol = settings['tol']
- tol_reset = settings['tol_reset']
- reset_interval = settings['reset_interval']
+ step = state.get('step', 0) + 1
+ state['step'] = step
+ init_scale = setting['init_scale']
+ ptol = setting['ptol']
+ ptol_reset = setting['ptol_reset']
+ gtol = setting['gtol']
+ reset_interval = setting['reset_interval']
  if reset_interval == 'auto': reset_interval = tensor.numel() + 1

- if M is None:
- M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
- if isinstance(init_scale, (int, float)) and init_scale != 1:
- if inverse: M /= init_scale
- else: M *= init_scale
+ if M is None or 'f_prev' not in state:
+ if M is None: # won't be true on reset_for_online
+ M = self._init_M(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
+ if isinstance(init_scale, (int, float)) and init_scale != 1:
+ if inverse: M /= init_scale
+ else: M *= init_scale

  state[M_key] = M
  state['f_prev'] = loss
@@ -97,190 +242,511 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  state['p_prev'].copy_(p)
  state['g_prev'].copy_(g)

- if reset_interval is not None and step != 0 and step % reset_interval == 0:
+ if reset_interval is not None and step % reset_interval == 0:
  self._reset_M_(M, s, y, inverse, init_scale, state)
  return

- # tolerance on gradient difference to avoid exploding after converging
- if y.abs().max() <= tol:
- # reset history
- if tol_reset: self._reset_M_(M, s, y, inverse, init_scale, state)
+ # tolerance on parameter difference to avoid exploding after converging
+ if ptol is not None and s.abs().max() <= ptol:
+ if ptol_reset: self._reset_M_(M, s, y, inverse, init_scale, state) # reset history
+ return
+
+ # tolerance on gradient difference to avoid exploding when there is no curvature
+ if gtol is not None and y.abs().max() <= gtol:
  return

- if step == 1 and init_scale == 'auto':
+ if step == 2 and init_scale == 'auto':
  if inverse: M /= self._get_init_scale(s,y)
  else: M *= self._get_init_scale(s,y)

- beta = settings['beta']
+ beta = setting['beta']
  if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place

  if inverse:
- H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
+ H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
  _maybe_lerp_(state, 'H', H_new, beta)

  else:
- B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
+ B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
  _maybe_lerp_(state, 'B', B_new, beta)

  state['f_prev'] = loss

+ def _post_B(self, B: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+ """modifies B before appling the update rule. Must return (B, g)"""
+ return B, g
+
+ def _post_H(self, H: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+ """modifies H before appling the update rule. Must return (H, g)"""
+ return H, g
+
  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, settings):
+ def apply_tensor(self, tensor, param, grad, loss, state, setting):
  step = state.get('step', 0)

- if settings['scale_second'] and step == 2:
- scale_factor = 1 / tensor.abs().sum().clip(min=1)
- scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
- tensor = tensor * scale_factor
+ if setting['scale_second'] and step == 2:
+ tensor = safe_scaling_(tensor)

- inverse = settings['inverse']
+ inverse = setting['inverse']
  if inverse:
  H = state['H']
- return (H @ tensor.view(-1)).view_as(tensor)
+ H, g = self._post_H(H, tensor.view(-1), state, setting)
+ if H.ndim == 1: return g.mul_(H).view_as(tensor)
+ return (H @ g).view_as(tensor)

  B = state['B']
+ H, g = self._post_B(B, tensor.view(-1), state, setting)
+
+ if B.ndim == 1: return g.div_(B).view_as(tensor)
+ x, info = torch.linalg.solve_ex(B, g) # pylint:disable=not-callable
+ if info == 0: return x.view_as(tensor)
+ return safe_scaling_(tensor)
+
+ class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
+ '''This is :code:`HessianUpdateStrategy` subclass for algorithms with no extra defaults, to skip the lengthy __init__.
+ Refer to :code:`HessianUpdateStrategy` documentation.
+
+ Example:
+ Implementing BFGS method that maintains an estimate of the hessian inverse (H):
+
+ .. code-block:: python
+
+ class BFGS(_HessianUpdateStrategyDefaults):
+ """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ tol = settings["tol"]
+ sy = torch.dot(s, y)
+ if sy <= tol: return H
+ num1 = (sy + (y @ H @ y)) * s.outer(s)
+ term1 = num1.div_(sy**2)
+ num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+ term2 = num2.div_(sy)
+ H += term1.sub_(term2)
+ return H
+
+ Make sure to put at least a basic class level docstring to overwrite this.
+ '''
+ def __init__(
+ self,
+ init_scale: float | Literal["auto"] = "auto",
+ tol: float = 1e-8,
+ ptol: float | None = 1e-10,
+ ptol_reset: bool = False,
+ gtol: float | None = 1e-10,
+ reset_interval: int | None = None,
+ beta: float | None = None,
+ update_freq: int = 1,
+ scale_first: bool = True,
+ scale_second: bool = False,
+ concat_params: bool = True,
+ inverse: bool = True,
+ inner: Chainable | None = None,
+ ):
+ super().__init__(
+ defaults=None,
+ init_scale=init_scale,
+ tol=tol,
+ ptol=ptol,
+ ptol_reset=ptol_reset,
+ gtol=gtol,
+ reset_interval=reset_interval,
+ beta=beta,
+ update_freq=update_freq,
+ scale_first=scale_first,
+ scale_second=scale_second,
+ concat_params=concat_params,
+ inverse=inverse,
+ inner=inner,
+ )

- return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable
-
- # to avoid typing all arguments for each method
- class HUpdateStrategy(HessianUpdateStrategy):
+ class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
  def __init__(
  self,
  init_scale: float | Literal["auto"] = "auto",
- tol: float = 1e-10,
- tol_reset: bool = True,
+ tol: float = 1e-8,
+ ptol: float | None = 1e-10,
+ ptol_reset: bool = False,
+ gtol: float | None = 1e-10,
  reset_interval: int | None = None,
  beta: float | None = None,
  update_freq: int = 1,
  scale_first: bool = True,
  scale_second: bool = False,
  concat_params: bool = True,
+ inverse: bool = False,
  inner: Chainable | None = None,
  ):
  super().__init__(
  defaults=None,
  init_scale=init_scale,
  tol=tol,
- tol_reset=tol_reset,
+ ptol=ptol,
+ ptol_reset=ptol_reset,
+ gtol=gtol,
  reset_interval=reset_interval,
  beta=beta,
  update_freq=update_freq,
  scale_first=scale_first,
  scale_second=scale_second,
  concat_params=concat_params,
- inverse=True,
+ inverse=inverse,
  inner=inner,
  )
+
  # ----------------------------------- BFGS ----------------------------------- #
+ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+ sy = s.dot(y)
+ if sy < tol: return B
+
+ Bs = B@s
+ sBs = _safe_clip(s.dot(Bs))
+
+ term1 = y.outer(y).div_(sy)
+ term2 = (Bs.outer(s) @ B.T).div_(sBs)
+ B += term1.sub_(term2)
+ return B
+
  def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
- sy = torch.dot(s, y)
- if sy <= tol: return H # don't reset H in this case
- num1 = (sy + (y @ H @ y)) * s.outer(s)
- term1 = num1.div_(sy**2)
- num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+ sy = s.dot(y)
+ if sy <= tol: return H
+
+ sy_sq = _safe_clip(sy**2)
+
+ Hy = H@y
+ scale1 = (sy + y.dot(Hy)) / sy_sq
+ term1 = s.outer(s).mul_(scale1)
+
+ num2 = (Hy.outer(s)).add_(s.outer(y @ H))
  term2 = num2.div_(sy)
+
  H += term1.sub_(term2)
  return H

- class BFGS(HUpdateStrategy):
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])
+ class BFGS(_InverseHessianUpdateStrategyDefaults):
+ """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe()` is recommended, although this can be stable without a line search. Alternatively warmup :code:`tz.m.Warmup` can stabilize quasi-newton methods without line search.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Args:
+ init_scale (float | Literal["auto"], optional):
+ initial hessian matrix is set to identity times this.
+
+ "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
+
+ Defaults to "auto".
+ tol (float, optional):
+ tolerance on curvature condition. Defaults to 1e-8.
+ ptol (float | None, optional):
+ skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
+ Defaults to 1e-10.
+ ptol_reset (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+ reset_interval (int | None | Literal["auto"], optional):
+ interval between resetting the hessian approximation.
+
+ "auto" corresponds to number of decision variables + 1.
+
+ None - no resets.
+
+ Defaults to None.
+ beta (float | None, optional): momentum on H or B. Defaults to None.
+ update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+ scale_first (bool, optional):
+ whether to downscale first step before hessian approximation becomes available. Defaults to True.
+ scale_second (bool, optional): whether to downscale second step. Defaults to False.
+ concat_params (bool, optional):
+ If true, all parameters are treated as a single vector.
+ If False, the update rule is applied to each parameter separately. Defaults to True.
+ inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+ Examples:
+ BFGS with strong-wolfe line search:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.BFGS(),
+ tz.m.StrongWolfe()
+ )
+
+ BFGS preconditioning applied to momentum:
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.BFGS(inner=tz.m.EMA(0.9)),
+ tz.m.LR(1e-2)
+ )
+ """
+
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+ return bfgs_B_(B=B, s=s, y=y, tol=setting['tol'])

  # ------------------------------------ SR1 ----------------------------------- #
- def sr1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
+ def sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
  z = s - H@y
- denom = torch.dot(z, y)
+ denom = z.dot(y)

  z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
  y_norm = torch.linalg.norm(y) # pylint:disable=not-callable

- if y_norm*z_norm < tol: return H
+ # if y_norm*z_norm < tol: return H

  # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
  if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
- H += torch.outer(z, z).div_(denom)
+ H += z.outer(z).div_(_safe_clip(denom))
  return H

- class SR1(HUpdateStrategy):
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])
+ class SR1(_InverseHessianUpdateStrategyDefaults):
+ """Symmetric Rank 1 Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ approximate Hessians generated by the SR1 method show faster progress towards the true Hessian than other methods, but it is more unstable. SR1 is best used within a trust region module.
+
+ .. note::
+ SR1 doesn't enforce the hessian estimate to be positive definite, therefore it can generate directions that are not descent directions.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Args:
+ init_scale (float | Literal["auto"], optional):
+ initial hessian matrix is set to identity times this.
+
+ "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
+
+ Defaults to "auto".
+ tol (float, optional):
+ tolerance for denominator in SR1 update rule as in Nocedal, Wright. “Numerical optimization” 2nd p.146. Defaults to 1e-8.
+ ptol (float | None, optional):
+ skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
+ Defaults to 1e-10.
+ ptol_reset (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+ reset_interval (int | None | Literal["auto"], optional):
+ interval between resetting the hessian approximation.
+
+ "auto" corresponds to number of decision variables + 1.
+
+ None - no resets.
+
+ Defaults to None.
+ beta (float | None, optional): momentum on H or B. Defaults to None.
+ update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+ scale_first (bool, optional):
+ whether to downscale first step before hessian approximation becomes available. Defaults to True.
+ scale_second (bool, optional): whether to downscale second step. Defaults to False.
+ concat_params (bool, optional):
+ If true, all parameters are treated as a single vector.
+ If False, the update rule is applied to each parameter separately. Defaults to True.
+ inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+ Examples:
+ SR1 with strong-wolfe line search
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.SR1(),
+ tz.m.StrongWolfe()
+ )
+
+ BFGS preconditioning applied to momentum
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.SR1(inner=tz.m.EMA(0.9)),
+ tz.m.LR(1e-2)
+ )
+ """
+
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return sr1_(H=H, s=s, y=y, tol=setting['tol'])
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+ return sr1_(H=B, s=y, y=s, tol=setting['tol'])
+

  # ------------------------------------ DFP ----------------------------------- #
  def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
- sy = torch.dot(s, y)
+ sy = s.dot(y)
  if sy.abs() <= tol: return H
- term1 = torch.outer(s, s).div_(sy)
- yHy = torch.dot(y, H @ y) #
- if yHy.abs() <= tol: return H
- num = H @ torch.outer(y, y) @ H
+ term1 = s.outer(s).div_(sy)
+
+ yHy = _safe_clip(y.dot(H @ y))
+
+ num = (H @ y).outer(y) @ H
  term2 = num.div_(yHy)
+
  H += term1.sub_(term2)
  return H

- class DFP(HUpdateStrategy):
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])
+ def dfp_B(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+ sy = s.dot(y)
+ if sy.abs() <= tol: return B
+ I = torch.eye(B.size(0), device=B.device, dtype=B.dtype)
+ sub = y.outer(s).div_(sy)
+ term1 = I - sub
+ term2 = I.sub_(sub.T)
+ term3 = y.outer(y).div_(sy)
+ B = (term1 @ B @ term2).add_(term3)
+ return B
+
+
+ class DFP(_InverseHessianUpdateStrategyDefaults):
+ """Davidon–Fletcher–Powell Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+ return dfp_B(B=B, s=s, y=y, tol=setting['tol'])


  # formulas for methods below from Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
  # H' = H - (Hy - S)c^T / c^T*y
  # the difference is how `c` is calculated

- def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+ def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
  c = H.T @ s
- cy = c.dot(y)
- if cy.abs() <= tol: return H
+ cy = _safe_clip(c.dot(y))
  num = (H@y).sub_(s).outer(c)
  H -= num/cy
  return H
+ def broyden_good_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+ r = y - B@s
+ ss = _safe_clip(s.dot(s))
+ B += r.outer(s).div_(ss)
+ return B

- def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
- c = y
- cy = c.dot(y)
- if cy.abs() <= tol: return H
- num = (H@y).sub_(s).outer(c)
- H -= num/cy
+ def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+ yy = _safe_clip(y.dot(y))
+ num = (s - (H @ y)).outer(y)
+ H += num/yy
  return H
+ def broyden_bad_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+ r = y - B@s
+ ys = _safe_clip(y.dot(s))
+ B += r.outer(y).div_(ys)
+ return B

- def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor, tol: float):
+ def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor):
  c = g_prev
- cy = c.dot(y)
- if cy.abs() <= tol: return H
+ cy = _safe_clip(c.dot(y))
  num = (H@y).sub_(s).outer(c)
  H -= num/cy
  return H

- def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
  Hy = H @ y
  c = H @ Hy # pylint:disable=not-callable
- cy = c.dot(y)
- if cy.abs() <= tol: return H
+ cy = _safe_clip(c.dot(y))
  num = Hy.sub_(s).outer(c)
  H -= num/cy
  return H

- class BroydenGood(HUpdateStrategy):
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])
+ class BroydenGood(_InverseHessianUpdateStrategyDefaults):
+ """Broyden's "good" Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return broyden_good_H_(H=H, s=s, y=y)
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+ return broyden_good_B_(B=B, s=s, y=y)
+
+ class BroydenBad(_InverseHessianUpdateStrategyDefaults):
+ """Broyden's "bad" Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return broyden_bad_H_(H=H, s=s, y=y)
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+ return broyden_bad_B_(B=B, s=s, y=y)
+
+ class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
+ """Greenstadt's first Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.

- class BroydenBad(HUpdateStrategy):
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
+ .. warning::
+ this uses roughly O(N^2) memory.

- class Greenstadt1(HUpdateStrategy):
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
+ Reference:
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev)
+
+ class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
+ """Greenstadt's second Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.

- class Greenstadt2(HUpdateStrategy):
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
+ .. warning::
+ this uses roughly O(N^2) memory.

+ Reference:
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

- def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float):
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return greenstadt2_H_(H=H, s=s, y=y)
+
+
+ def icum_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
  j = y.abs().argmax()

- denom = y[j]
- if denom.abs() < tol: return H
+ denom = _safe_clip(y[j])

  Hy = H @ y.unsqueeze(1)
  num = s.unsqueeze(1) - Hy
@@ -288,31 +754,55 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
  H[:, j] += num.squeeze() / denom
  return H

- class ColumnUpdatingMethod(HUpdateStrategy):
- """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
+ class ICUM(_InverseHessianUpdateStrategyDefaults):
+ """
+ Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
+ due to only updating one column of the inverse hessian approximation per step.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return icum_H_(H=H, s=s, y=y)

- def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
  s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
  I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
  d = (R + I * (s_norm/2)) @ s
- ds = d.dot(s)
- if ds.abs() <= tol: return H, R
+ ds = _safe_clip(d.dot(s))
  R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))

  c = H.T @ d
- cy = c.dot(y)
- if cy.abs() <= tol: return H, R
+ cy = _safe_clip(c.dot(y))
  num = (H@y).sub_(s).outer(c)
  H -= num/cy
  return H, R

- class ThomasOptimalMethod(HUpdateStrategy):
- """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
+ """
+ Thomas's "optimal" Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
  if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
- H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
+ H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
  return H

  def _reset_M_(self, M, s, y,inverse, init_scale, state):
@@ -321,97 +811,120 @@ class ThomasOptimalMethod(HUpdateStrategy):
  st.pop("R", None)

  # ------------------------ powell's symmetric broyden ------------------------ #
- def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
  y_Bs = y - B@s
- ss = s.dot(s)
- if ss.abs() < tol: return B
+ ss = _safe_clip(s.dot(s))
  num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
  term1 = num1.div_(ss)
- term2 = s.outer(s).mul_(y_Bs.dot(s)/(ss**2))
+ term2 = s.outer(s).mul_(y_Bs.dot(s)/(_safe_clip(ss**2)))
  B += term1.sub_(term2)
  return B

  # I couldn't find formula for H
- class PSB(HessianUpdateStrategy):
- def __init__(
- self,
- init_scale: float | Literal["auto"] = 'auto',
- tol: float = 1e-10,
- tol_reset: bool = True,
- reset_interval: int | None = None,
- beta: float | None = None,
- update_freq: int = 1,
- scale_first: bool = True,
- scale_second: bool = False,
- concat_params: bool = True,
- inner: Chainable | None = None,
- ):
- super().__init__(
- defaults=None,
- init_scale=init_scale,
- tol=tol,
- tol_reset=tol_reset,
- reset_interval=reset_interval,
- beta=beta,
- update_freq=update_freq,
- scale_first=scale_first,
- scale_second=scale_second,
- concat_params=concat_params,
- inverse=False,
- inner=inner,
- )
+ class PSB(_HessianUpdateStrategyDefaults):
+ """Powell's Symmetric Broyden Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.

- def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
- return psb_B_(B=B, s=s, y=y, tol=settings['tol'])
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+ """
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+ return psb_B_(B=B, s=s, y=y)


  # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
- def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+ def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
  Hy = H@y
- yHy = y.dot(Hy)
- if yHy.abs() <= tol: return H
+ yHy = _safe_clip(y.dot(Hy))
  num = (s - Hy).outer(Hy)
  H += num.div_(yHy)
  return H

- class Pearson(HUpdateStrategy):
- """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+ class Pearson(_InverseHessianUpdateStrategyDefaults):
+ """
+ Pearson's Quasi-Newton method.

- This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return pearson_H_(H=H, s=s, y=y, tol=settings['tol'])
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.

- def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
- sy = s.dot(y)
- if sy.abs() <= tol: return H
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return pearson_H_(H=H, s=s, y=y)
+
+ def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+ sy = _safe_clip(s.dot(y))
  num = (s - H@y).outer(s)
  H += num.div_(sy)
  return H

- class McCormick(HUpdateStrategy):
- """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+ class McCormick(_InverseHessianUpdateStrategyDefaults):
+ """McCormicks's Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.

- This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return mccormick_H_(H=H, s=s, y=y, tol=settings['tol'])
+ .. warning::
+ this uses roughly O(N^2) memory.

- def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+ Reference:
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+ This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method in other sources.
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return mccormick_H_(H=H, s=s, y=y)
+
+ def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
  Hy = H @ y
- yHy = y.dot(Hy)
- if yHy.abs() < tol: return H, R
+ yHy = _safe_clip(y.dot(Hy))
  H -= Hy.outer(Hy) / yHy
  R += (s - R@y).outer(Hy) / yHy
  return H, R

  class ProjectedNewtonRaphson(HessianUpdateStrategy):
- """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+ """
+ Projected Newton Raphson method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ this is an experimental method.

- Algorithm 7"""
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+ This one is Algorithm 7.
+ """
  def __init__(
  self,
  init_scale: float | Literal["auto"] = 'auto',
- tol: float = 1e-10,
- tol_reset: bool = True,
+ tol: float = 1e-8,
+ ptol: float | None = 1e-10,
+ ptol_reset: bool = False,
+ gtol: float | None = 1e-10,
  reset_interval: int | None | Literal['auto'] = 'auto',
  beta: float | None = None,
  update_freq: int = 1,
@@ -423,7 +936,9 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
  super().__init__(
  init_scale=init_scale,
  tol=tol,
- tol_reset=tol_reset,
+ ptol = ptol,
+ ptol_reset=ptol_reset,
+ gtol=gtol,
  reset_interval=reset_interval,
  beta=beta,
  update_freq=update_freq,
@@ -434,9 +949,9 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
  inner=inner,
  )

- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
  if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
- H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
+ H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y)
  state["R"] = R
  return H

@@ -454,12 +969,10 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
  # however p.12 says eps = gs / gHy

  Hy = H@y
- gHy = g.dot(Hy)
- yHy = y.dot(Hy)
+ gHy = _safe_clip(g.dot(Hy))
+ yHy = _safe_clip(y.dot(Hy))
  sy = s.dot(y)
- if sy < tol: return H
- if yHy.abs() < tol: return H
- if gHy.abs() < tol: return H
+ if sy < tol: return H # the proof is for sy>0. But not clear if it should be skipped

  v_mul = yHy.sqrt()
  v_term1 = s/sy
@@ -474,28 +987,26 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
  e = gs / gHy
  if switch in (1, 3):
  if e/o <= 1:
- if o.abs() <= tol: return H
- phi = e/o
+ phi = e/_safe_clip(o)
  theta = 0
  elif o/t >= 1:
- if t.abs() <= tol: return H
- phi = o/t
+ phi = o/_safe_clip(t)
  theta = 1
  else:
  phi = 1
- denom = e*t - o**2
- if denom.abs() <= tol: return H
+ denom = _safe_clip(e*t - o**2)
  if switch == 1: theta = o * (e - o) / denom
  else: theta = o * (t - o) / denom

  elif switch == 2:
- if t.abs() <= tol or o.abs() <= tol or e.abs() <= tol: return H
+ t = _safe_clip(t)
+ o = _safe_clip(o)
+ e = _safe_clip(e)
  phi = (e / t) ** 0.5
  theta = 1 / (1 + (t*e / o**2)**0.5)

  elif switch == 4:
- if t.abs() <= tol: return H
- phi = e/t
+ phi = e/_safe_clip(t)
  theta = 1/2

  else: raise ValueError(switch)
@@ -514,14 +1025,29 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi


  class SSVM(HessianUpdateStrategy):
- """This one is from Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
+ """
+ Self-scaling variable metric Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
  """
  def __init__(
  self,
  switch: tuple[float,float] | Literal[1,2,3,4] = 3,
  init_scale: float | Literal["auto"] = 'auto',
- tol: float = 1e-10,
- tol_reset: bool = True,
+ tol: float = 1e-8,
+ ptol: float | None = 1e-10,
+ ptol_reset: bool = False,
+ gtol: float | None = 1e-10,
  reset_interval: int | None = None,
  beta: float | None = None,
  update_freq: int = 1,
@@ -535,7 +1061,9 @@ class SSVM(HessianUpdateStrategy):
  defaults=defaults,
  init_scale=init_scale,
  tol=tol,
- tol_reset=tol_reset,
+ ptol=ptol,
+ ptol_reset=ptol_reset,
+ gtol=gtol,
  reset_interval=reset_interval,
  beta=beta,
  update_freq=update_freq,
@@ -546,17 +1074,16 @@ class SSVM(HessianUpdateStrategy):
  inner=inner,
  )

- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return ssvm_H_(H=H, s=s, y=y, g=g, switch=setting['switch'], tol=setting['tol'])

  # HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
  def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  Hy = H@y
  ys = y.dot(s)
- if ys.abs() <= tol: return H
+ if ys.abs() <= tol: return H # probably? because it is BFGS and DFP-like
  yHy = y.dot(Hy)
- denom = ys + yHy
- if denom.abs() <= tol: return H
+ denom = _safe_clip(ys + yHy)

  term1 = 1/denom
  term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
@@ -569,19 +1096,35 @@ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  return H

  def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
- sy = s.dot(y)
- if sy.abs() < torch.finfo(g[0].dtype).eps: return g
+ sy = _safe_clip(s.dot(y))
  return g - (y * (s.dot(g) / sy))


  class GradientCorrection(Transform):
- """estimates gradient at minima along search direction assuming function is quadratic as proposed in HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+ """
+ Estimates gradient at minima along search direction assuming function is quadratic.
+
+ This can be useful as an inner module for second order methods with an inexact line search.
+
+ Example:
+ L-BFGS with gradient correction
+
+ .. code-block :: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LBFGS(inner=tz.m.GradientCorrection()),
+ tz.m.Backtracking()
+ )
+
+ Reference:
+ HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394

- This can useful as inner module for second order methods."""
+ """
  def __init__(self):
  super().__init__(None, uses_grad=False)

- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  if 'p_prev' not in states[0]:
  p_prev = unpack_states(states, tensors, 'p_prev', init=params)
  g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
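The corrected gradient returned by gradient_correction, g - y * (s.dot(g) / s.dot(y)), is exactly the gradient at the minimizer of a quadratic along the step s, and is therefore orthogonal to s. A small self-contained check of that identity, written with plain tensors rather than the TensorList type the function itself operates on:

import torch

torch.manual_seed(0)
A = torch.randn(5, 5)
A = A @ A.T + 5 * torch.eye(5)          # SPD Hessian of the quadratic f(x) = 0.5 * x^T A x

x_prev, x = torch.randn(5), torch.randn(5)
g_prev, g = A @ x_prev, A @ x           # exact gradients of the quadratic
s, y = x - x_prev, g - g_prev

g_hat = g - y * (s.dot(g) / s.dot(y))   # same formula as gradient_correction above

# the gradient at the exact line minimum is orthogonal to the search direction
print(abs(s.dot(g_hat).item()) < 1e-4)  # True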
@@ -594,15 +1137,30 @@ class GradientCorrection(Transform):
  g_prev.copy_(tensors)
  return g_hat

- class Horisho(HUpdateStrategy):
- """HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394"""
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return hoshino_H_(H=H, s=s, y=y, tol=settings['tol'])
+ class Horisho(_InverseHessianUpdateStrategyDefaults):
+ """
+ Horisho's variable metric Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+ """
+
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return hoshino_H_(H=H, s=s, y=y, tol=setting['tol'])

  # Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
  def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  sy = s.dot(y)
- if sy.abs() < tol: return H
+ if sy.abs() < tol: return H # part of algorithm
  Hy = H @ y

  term1 = (s.outer(y) @ H).div_(sy)
@@ -613,16 +1171,30 @@ def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float)
  H -= (term1 + term2 - term4.mul_(term3))
  return H

- class FletcherVMM(HUpdateStrategy):
- """Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317"""
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return fletcher_vmm_H_(H=H, s=s, y=y, tol=settings['tol'])
+ class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
+ """
+ Fletcher's variable metric Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+ .. note::
+ BFGS is the recommended QN method and will usually outperform this.
+
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return fletcher_vmm_H_(H=H, s=s, y=y, tol=setting['tol'])


  # Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
  def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
  sy = s.dot(y)
- if sy < tol: return H
+ if sy < tol: return H # part of algorithm

  term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy

@@ -644,15 +1216,25 @@ def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol:


  class NewSSM(HessianUpdateStrategy):
- """Self-scaling method, requires a line search.
+ """Self-scaling Quasi-Newton method.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.

- Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U."""
+ .. warning::
+ this uses roughly O(N^2) memory.
+
+ Reference:
+ Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+ """
  def __init__(
  self,
  type: Literal[1, 2] = 1,
  init_scale: float | Literal["auto"] = "auto",
- tol: float = 1e-10,
- tol_reset: bool = True,
+ tol: float = 1e-8,
+ ptol: float | None = 1e-10,
+ ptol_reset: bool = False,
+ gtol: float | None = 1e-10,
  reset_interval: int | None = None,
  beta: float | None = None,
  update_freq: int = 1,
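Unlike SSVM, where the docstring only recommends a line search, the NewSSM docstring above marks it as required, and the hunks below show the update reading the consecutive loss values f and f_prev from state. A minimal setup consistent with that; model is again an assumed torch.nn.Module and type=1 is the default from the signature:

import torchzero as tz

opt = tz.Modular(
    model.parameters(),                 # `model` is an assumed, already-built torch.nn.Module
    tz.m.NewSSM(type=1),                # type selects one of the paper's two scaling formulas
    tz.m.StrongWolfe(plus_minus=True),  # required per the docstring
)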
@@ -665,7 +1247,9 @@ class NewSSM(HessianUpdateStrategy):
  defaults=dict(type=type),
  init_scale=init_scale,
  tol=tol,
- tol_reset=tol_reset,
+ ptol=ptol,
+ ptol_reset=ptol_reset,
+ gtol=gtol,
  reset_interval=reset_interval,
  beta=beta,
  update_freq=update_freq,
@@ -675,9 +1259,73 @@ class NewSSM(HessianUpdateStrategy):
  inverse=True,
  inner=inner,
  )
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
  f = state['f']
  f_prev = state['f_prev']
- return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=settings['type'], tol=settings['tol'])
+ return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=setting['type'], tol=setting['tol'])
+
+ # ---------------------------- Shor's r-algorithm ---------------------------- #
+ # def shor_r(B:torch.Tensor, y:torch.Tensor, gamma:float):
+ # r = B.T @ y
+ # r /= torch.linalg.vector_norm(r).clip(min=1e-8) # pylint:disable=not-callable
+
+ # I = torch.eye(B.size(1), device=B.device, dtype=B.dtype)
+ # return B @ (I - gamma*r.outer(r))
+
+ # this is supposed to be equivalent
+ def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
+ p = H@y
+ #(1-y)^2 (ppT)/(pTq)
+ term = p.outer(p).div_(p.dot(y).clip(min=1e-8))
+ H.sub_(term, alpha=1-alpha**2)
+ return H
+
+ class ShorR(HessianUpdateStrategy):
+ """Shor's r-algorithm.
+
+ .. note::
+ a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.
+
+ Reference:
+ Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720.
+
+ Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
+ """

+ def __init__(
+ self,
+ alpha=0.5,
+ init_scale: float | Literal["auto"] = 1,
+ tol: float = 1e-8,
+ ptol: float | None = 1e-10,
+ ptol_reset: bool = False,
+ gtol: float | None = 1e-10,
+ reset_interval: int | None | Literal['auto'] = None,
+ beta: float | None = None,
+ update_freq: int = 1,
+ scale_first: bool = False,
+ scale_second: bool = False,
+ concat_params: bool = True,
+ # inverse: bool = True,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(alpha=alpha)
+ super().__init__(
+ defaults=defaults,
+ init_scale=init_scale,
+ tol=tol,
+ ptol=ptol,
+ ptol_reset=ptol_reset,
+ gtol=gtol,
+ reset_interval=reset_interval,
+ beta=beta,
+ update_freq=update_freq,
+ scale_first=scale_first,
+ scale_second=scale_second,
+ concat_params=concat_params,
+ inverse=True,
+ inner=inner,
+ )

+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return shor_r_(H=H, y=y, alpha=setting['alpha'])
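The commented-out shor_r above updates a factor B of the metric, while the active shor_r_ updates H directly, and the diff only notes that the two are "supposed to be equivalent". A quick numerical check of that claim under the identification H = B @ B.T and gamma = 1 - alpha; that pairing of the two coefficients is my reading of the (1-y)^2 comment, not something stated in the diff:

import torch

def shor_r_b_form(B, y, gamma):
    # the commented-out factor-form update from the diff
    r = B.T @ y
    r = r / torch.linalg.vector_norm(r).clip(min=1e-8)
    I = torch.eye(B.size(1), dtype=B.dtype)
    return B @ (I - gamma * r.outer(r))

def shor_r_h_form(H, y, alpha):
    # the active shor_r_ update from the diff, written out of place
    p = H @ y
    term = p.outer(p) / p.dot(y).clip(min=1e-8)
    return H - (1 - alpha**2) * term

torch.manual_seed(0)
B = torch.randn(6, 6, dtype=torch.float64)
y = torch.randn(6, dtype=torch.float64)

alpha = 0.5
B_new = shor_r_b_form(B, y, gamma=1 - alpha)
H_new = shor_r_h_form(B @ B.T, y, alpha)
print(torch.allclose(B_new @ B_new.T, H_new))  # True

Per the new ShorR docstring, the module itself is meant to sit in a tz.Modular chain in front of tz.m.StrongWolfe(plus_minus=True), like the other quasi-Newton modules touched by this diff.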