torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,18 +1,16 @@
1
- """Use BFGS or maybe SR1."""
1
+ import warnings
2
2
  from abc import ABC, abstractmethod
3
- from collections.abc import Mapping
3
+ from collections.abc import Callable, Mapping
4
4
  from typing import Any, Literal
5
5
 
6
6
  import torch
7
7
 
8
8
  from ...core import Chainable, Module, TensorwiseTransform, Transform
9
- from ...utils import TensorList, set_storage_, unpack_states
9
+ from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
10
+ from ...utils.linalg import linear_operator
11
+ from ..functional import initial_step_size, safe_clip
10
12
 
11
13
 
12
- def _safe_dict_update_(d1_:dict, d2:dict):
13
- inter = set(d1_.keys()).intersection(d2.keys())
14
- if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
15
- d1_.update(d2)
16
14
 
17
15
  def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
18
16
  if (beta is None) or (beta == 0) or (key not in state): state[key] = value
@@ -20,68 +18,165 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
20
18
  else: state[key].lerp_(value, 1-beta)
21
19
 
22
20
  class HessianUpdateStrategy(TensorwiseTransform, ABC):
21
+ """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.
22
+
23
+ This is an abstract class. To use it, subclass it and override ``update_H`` and/or ``update_B``,
24
+ and if necessary, ``initialize_P``, ``modify_H`` and ``modify_B``.
25
+
26
+ Args:
27
+ defaults (dict | None, optional): defaults. Defaults to None.
28
+ init_scale (float | Literal["auto"], optional):
29
+ initial hessian matrix is set to identity times this.
30
+
31
+ "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
32
+
33
+ Defaults to "auto".
34
+ tol (float, optional):
35
+ algorithm-dependent tolerance (usually on curvature condition). Defaults to 1e-32.
36
+ ptol (float | None, optional):
37
+ tolerance for minimal parameter difference to avoid instability. Defaults to 1e-32.
38
+ ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
39
+ gtol (float | None, optional):
40
+ tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-32.
41
+ restart_interval (int | None | Literal["auto"], optional):
42
+ interval between resetting the hessian approximation.
43
+
44
+ "auto" corresponds to number of decision variables + 1.
45
+
46
+ None - no resets.
47
+
48
+ Defaults to None.
49
+ beta (float | None, optional): momentum on H or B. Defaults to None.
50
+ update_freq (int, optional): frequency of updating H or B. Defaults to 1.
51
+ scale_first (bool, optional):
52
+ whether to downscale the first step, before the Hessian approximation becomes available. Defaults to False.
53
+ scale_second (bool, optional): whether to downscale second step. Defaults to False.
54
+ concat_params (bool, optional):
55
+ If true, all parameters are treated as a single vector.
56
+ If False, the update rule is applied to each parameter separately. Defaults to True.
57
+ inverse (bool, optional):
58
+ Set to True if this method maintains a Hessian inverse approximation H and implements ``update_H``.
59
+ Set to False if it maintains a Hessian approximation B and implements ``update_B``.
60
+ Defaults to True.
61
+ inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
62
+
63
+ ## Notes
64
+
65
+ ### update
66
+
67
+ On the first ``update_tensor`` call, H or B is initialized using ``initialize_P``, which returns an identity matrix by default.
68
+
69
+ The second and subsequent ``update_tensor`` calls invoke ``update_H`` or ``update_B``.
70
+
71
+ Whether ``H`` or ``B`` is used depends on the value of the ``inverse`` setting.
72
+
73
+ ### apply
74
+
75
+ ``apply_tensor`` computes ``H = modify_H(H)`` or ``B = modify_B(B)``; those methods do nothing by default.
76
+
77
+ Then it computes and returns ``H @ input`` or ``solve(B, input)``.
78
+
79
+ Whether ``H`` or ``B`` is used depends on the value of the ``inverse`` setting.
80
+
81
+ ### initial scale
82
+
83
+ If ``init_scale`` is a scalar, the preconditioner is multiplied or divided (if inverse) by it on first ``update_tensor``.
84
+
85
+ If ``init_scale="auto"``, it is computed and applied on the second ``update_tensor``.
86
+
87
+ ### get_H
88
+
89
+ First it computes ``H = modify_H(H)`` or ``B = modify_B(B)``.
90
+
91
+ Returns a ``Dense`` linear operator with ``B``, or ``DenseInverse`` linear operator with ``H``.
92
+
93
+ However, if H/B is one-dimensional, a ``Diagonal`` linear operator is returned with ``B`` or ``1/H``.
94
+ """
23
95
  def __init__(
24
96
  self,
25
97
  defaults: dict | None = None,
26
98
  init_scale: float | Literal["auto"] = "auto",
27
- tol: float = 1e-10,
28
- tol_reset: bool = True,
29
- reset_interval: int | None | Literal['auto'] = None,
99
+ tol: float = 1e-32,
100
+ ptol: float | None = 1e-32,
101
+ ptol_restart: bool = False,
102
+ gtol: float | None = 1e-32,
103
+ restart_interval: int | None | Literal['auto'] = None,
30
104
  beta: float | None = None,
31
105
  update_freq: int = 1,
32
- scale_first: bool = True,
33
- scale_second: bool = False,
106
+ scale_first: bool = False,
34
107
  concat_params: bool = True,
35
108
  inverse: bool = True,
36
109
  inner: Chainable | None = None,
37
110
  ):
38
111
  if defaults is None: defaults = {}
39
- _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, tol_reset=tol_reset, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
40
- super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, scale_first=scale_first, inner=inner)
112
+ safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_restart=ptol_restart, gtol=gtol, inverse=inverse, beta=beta, restart_interval=restart_interval, scale_first=scale_first))
113
+ super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
41
114
 
42
- def _get_init_scale(self,s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
43
- """returns multiplier to H or B"""
44
- ys = y.dot(s)
45
- yy = y.dot(y)
46
- if ys != 0 and yy != 0: return yy/ys
47
- return 1
115
+ def reset_for_online(self):
116
+ super().reset_for_online()
117
+ self.clear_state_keys('f_prev', 'p_prev', 'g_prev')
48
118
 
49
- def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
50
- set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
51
- if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
52
- if init_scale >= 1:
53
- if inverse: M /= init_scale
54
- else: M *= init_scale
119
+ # ---------------------------- methods to override --------------------------- #
120
+ def initialize_P(self, size:int, device, dtype, is_inverse:bool) -> torch.Tensor:
121
+ """returns the initial torch.Tensor for H or B"""
122
+ return torch.eye(size, device=device, dtype=dtype)
55
123
 
56
124
  def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
57
- p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
125
+ p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
58
126
  """update hessian inverse"""
59
- raise NotImplementedError
127
+ raise NotImplementedError(f"hessian inverse approximation is not implemented for {self.__class__.__name__}.")
60
128
 
61
129
  def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
62
- p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
130
+ p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
63
131
  """update hessian"""
64
- raise NotImplementedError
132
+ raise NotImplementedError(f"{self.__class__.__name__} only supports hessian inverse approximation. "
133
+ "Remove the `inverse=False` argument when initializing this module.")
134
+
135
+ def modify_B(self, B: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
136
+ """modifies B out of place before appling the update rule, doesn't affect the buffer B."""
137
+ return B
138
+
139
+ def modify_H(self, H: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
140
+ """modifies H out of place before appling the update rule, doesn't affect the buffer H."""
141
+ return H
142
+
143
+ # ------------------------------ common methods ------------------------------ #
144
+ def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
145
+ """returns multiplier to B on 2nd step if ``init_scale='auto'``. H should be divided by this!"""
146
+ ys = y.dot(s)
147
+ yy = y.dot(y)
148
+ if ys != 0 and yy != 0: return yy/ys
149
+ return 1
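For reference, the ``'auto'`` initial scaling implemented above is the heuristic from Nocedal & Wright (Numerical Optimization, pp. 142-143) cited in the class docstring: the returned value multiplies B, and H is divided by the same quantity. Written out (added here for readability, not part of the diff):

```latex
B_0 = \frac{y^\top y}{y^\top s}\, I, \qquad H_0 = \frac{y^\top s}{y^\top y}\, I
```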
150
+
151
+ def reset_P(self, P: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
152
+ """resets ``P`` which is either B or H"""
153
+ set_storage_(P, self.initialize_P(s.numel(), device=P.device, dtype=P.dtype, is_inverse=inverse))
154
+ if init_scale == 'auto': init_scale = self.auto_initial_scale(s,y)
155
+ if init_scale >= 1:
156
+ if inverse: P /= init_scale
157
+ else: P *= init_scale
65
158
 
66
159
  @torch.no_grad
67
- def update_tensor(self, tensor, param, grad, loss, state, settings):
160
+ def update_tensor(self, tensor, param, grad, loss, state, setting):
68
161
  p = param.view(-1); g = tensor.view(-1)
69
- inverse = settings['inverse']
162
+ inverse = setting['inverse']
70
163
  M_key = 'H' if inverse else 'B'
71
164
  M = state.get(M_key, None)
72
- step = state.get('step', 0)
73
- state['step'] = step + 1
74
- init_scale = settings['init_scale']
75
- tol = settings['tol']
76
- tol_reset = settings['tol_reset']
77
- reset_interval = settings['reset_interval']
78
- if reset_interval == 'auto': reset_interval = tensor.numel() + 1
79
-
80
- if M is None:
81
- M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
82
- if isinstance(init_scale, (int, float)) and init_scale != 1:
83
- if inverse: M /= init_scale
84
- else: M *= init_scale
165
+ step = state.get('step', 0) + 1
166
+ state['step'] = step
167
+ init_scale = setting['init_scale']
168
+ ptol = setting['ptol']
169
+ ptol_restart = setting['ptol_restart']
170
+ gtol = setting['gtol']
171
+ restart_interval = setting['restart_interval']
172
+ if restart_interval == 'auto': restart_interval = tensor.numel() + 1
173
+
174
+ if M is None or 'f_prev' not in state:
175
+ if M is None: # won't be true on reset_for_online
176
+ M = self.initialize_P(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
177
+ if isinstance(init_scale, (int, float)) and init_scale != 1:
178
+ if inverse: M /= init_scale
179
+ else: M *= init_scale
85
180
 
86
181
  state[M_key] = M
87
182
  state['f_prev'] = loss
@@ -97,190 +192,487 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
97
192
  state['p_prev'].copy_(p)
98
193
  state['g_prev'].copy_(g)
99
194
 
100
- if reset_interval is not None and step != 0 and step % reset_interval == 0:
101
- self._reset_M_(M, s, y, inverse, init_scale, state)
195
+ if restart_interval is not None and step % restart_interval == 0:
196
+ self.reset_P(M, s, y, inverse, init_scale, state)
197
+ return
198
+
199
+ # tolerance on parameter difference to avoid exploding after converging
200
+ if ptol is not None and s.abs().max() <= ptol:
201
+ if ptol_restart: self.reset_P(M, s, y, inverse, init_scale, state) # reset history
102
202
  return
103
203
 
104
- # tolerance on gradient difference to avoid exploding after converging
105
- if y.abs().max() <= tol:
106
- # reset history
107
- if tol_reset: self._reset_M_(M, s, y, inverse, init_scale, state)
204
+ # tolerance on gradient difference to avoid exploding when there is no curvature
205
+ if gtol is not None and y.abs().max() <= gtol:
108
206
  return
109
207
 
110
- if step == 1 and init_scale == 'auto':
111
- if inverse: M /= self._get_init_scale(s,y)
112
- else: M *= self._get_init_scale(s,y)
208
+ if step == 2 and init_scale == 'auto':
209
+ if inverse: M /= self.auto_initial_scale(s,y)
210
+ else: M *= self.auto_initial_scale(s,y)
113
211
 
114
- beta = settings['beta']
212
+ beta = setting['beta']
115
213
  if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
116
214
 
117
215
  if inverse:
118
- H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
216
+ H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
119
217
  _maybe_lerp_(state, 'H', H_new, beta)
120
218
 
121
219
  else:
122
- B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
220
+ B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
123
221
  _maybe_lerp_(state, 'B', B_new, beta)
124
222
 
125
223
  state['f_prev'] = loss
126
224
 
127
225
  @torch.no_grad
128
- def apply_tensor(self, tensor, param, grad, loss, state, settings):
129
- step = state.get('step', 0)
226
+ def apply_tensor(self, tensor, param, grad, loss, state, setting):
227
+ step = state['step']
228
+
229
+ if setting['scale_first'] and step == 1:
230
+ tensor *= initial_step_size(tensor)
130
231
 
131
- if settings['scale_second'] and step == 2:
132
- scale_factor = 1 / tensor.abs().sum().clip(min=1)
133
- scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
134
- tensor = tensor * scale_factor
232
+ inverse = setting['inverse']
233
+ g = tensor.view(-1)
135
234
 
136
- inverse = settings['inverse']
137
235
  if inverse:
138
236
  H = state['H']
139
- return (H @ tensor.view(-1)).view_as(tensor)
237
+ H = self.modify_H(H, state, setting)
238
+ if H.ndim == 1: return g.mul_(H).view_as(tensor)
239
+ return (H @ g).view_as(tensor)
140
240
 
141
241
  B = state['B']
242
+ B = self.modify_B(B, state, setting)
243
+
244
+ if B.ndim == 1: return g.div_(B).view_as(tensor)
245
+ x, info = torch.linalg.solve_ex(B, g) # pylint:disable=not-callable
246
+ if info == 0: return x.view_as(tensor)
247
+
248
+ # failed to solve linear system, so reset state
249
+ self.state.clear()
250
+ self.global_state.clear()
251
+ return tensor.mul_(initial_step_size(tensor))
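In other words, ``apply_tensor`` above turns the incoming update g into a preconditioned direction d: an elementwise product or division when H or B is stored as a diagonal, d = H g when the inverse is maintained, and the solution of B d = g via ``torch.linalg.solve_ex`` otherwise, falling back to a state reset and a conservatively rescaled step if the solve fails. As a sketch (not part of the diff):

```latex
d = H g \quad \text{(inverse form)}, \qquad B d = g \quad \text{(direct form)}
```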
252
+
253
+ def get_H(self, var):
254
+ param = var.params[0]
255
+ state = self.state[param]
256
+ settings = self.settings[param]
257
+ if "B" in state:
258
+ B = self.modify_B(state["B"], state, settings)
259
+ if B.ndim == 2: return linear_operator.Dense(B)
260
+ assert B.ndim == 1, B.shape
261
+ return linear_operator.Diagonal(B)
262
+
263
+ if "H" in state:
264
+ H = self.modify_H(state["H"], state, settings)
265
+ if H.ndim != 1: return linear_operator.DenseInverse(H)
266
+ return linear_operator.Diagonal(1/H)
267
+
268
+ return None
269
+
270
+ class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
271
+ '''This is a ``HessianUpdateStrategy`` subclass for algorithms with no extra defaults, to skip the lengthy ``__init__``.
272
+ Refer to ``HessianUpdateStrategy`` documentation.
273
+
274
+ ## Example:
275
+
276
+ Implementing the BFGS method, which maintains an estimate of the Hessian inverse (H):
277
+ ```python
278
+ class BFGS(_InverseHessianUpdateStrategyDefaults):
279
+ """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
280
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
281
+ tol = settings["tol"]
282
+ sy = torch.dot(s, y)
283
+ if sy <= tol: return H
284
+ num1 = (sy + (y @ H @ y)) * s.outer(s)
285
+ term1 = num1.div_(sy**2)
286
+ num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
287
+ term2 = num2.div_(sy)
288
+ H += term1.sub_(term2)
289
+ return H
290
+ ```
291
+
292
+ Make sure to add at least a basic class-level docstring to override this one.
293
+ '''
294
+ def __init__(
295
+ self,
296
+ init_scale: float | Literal["auto"] = "auto",
297
+ tol: float = 1e-32,
298
+ ptol: float | None = 1e-32,
299
+ ptol_restart: bool = False,
300
+ gtol: float | None = 1e-32,
301
+ restart_interval: int | None = None,
302
+ beta: float | None = None,
303
+ update_freq: int = 1,
304
+ scale_first: bool = False,
305
+ concat_params: bool = True,
306
+ inverse: bool = True,
307
+ inner: Chainable | None = None,
308
+ ):
309
+ super().__init__(
310
+ defaults=None,
311
+ init_scale=init_scale,
312
+ tol=tol,
313
+ ptol=ptol,
314
+ ptol_restart=ptol_restart,
315
+ gtol=gtol,
316
+ restart_interval=restart_interval,
317
+ beta=beta,
318
+ update_freq=update_freq,
319
+ scale_first=scale_first,
320
+ concat_params=concat_params,
321
+ inverse=inverse,
322
+ inner=inner,
323
+ )
142
324
 
143
- return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable
144
-
145
- # to avoid typing all arguments for each method
146
- class HUpdateStrategy(HessianUpdateStrategy):
325
+ class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
147
326
  def __init__(
148
327
  self,
149
328
  init_scale: float | Literal["auto"] = "auto",
150
- tol: float = 1e-10,
151
- tol_reset: bool = True,
152
- reset_interval: int | None = None,
329
+ tol: float = 1e-32,
330
+ ptol: float | None = 1e-32,
331
+ ptol_restart: bool = False,
332
+ gtol: float | None = 1e-32,
333
+ restart_interval: int | None = None,
153
334
  beta: float | None = None,
154
335
  update_freq: int = 1,
155
- scale_first: bool = True,
156
- scale_second: bool = False,
336
+ scale_first: bool = False,
157
337
  concat_params: bool = True,
338
+ inverse: bool = False,
158
339
  inner: Chainable | None = None,
159
340
  ):
160
341
  super().__init__(
161
342
  defaults=None,
162
343
  init_scale=init_scale,
163
344
  tol=tol,
164
- tol_reset=tol_reset,
165
- reset_interval=reset_interval,
345
+ ptol=ptol,
346
+ ptol_restart=ptol_restart,
347
+ gtol=gtol,
348
+ restart_interval=restart_interval,
166
349
  beta=beta,
167
350
  update_freq=update_freq,
168
351
  scale_first=scale_first,
169
- scale_second=scale_second,
170
352
  concat_params=concat_params,
171
- inverse=True,
353
+ inverse=inverse,
172
354
  inner=inner,
173
355
  )
356
+
174
357
  # ----------------------------------- BFGS ----------------------------------- #
358
+ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
359
+ sy = s.dot(y)
360
+ if sy < tol: return B
361
+
362
+ Bs = B@s
363
+ sBs = safe_clip(s.dot(Bs))
364
+
365
+ term1 = y.outer(y).div_(sy)
366
+ term2 = (Bs.outer(s) @ B.T).div_(sBs)
367
+ B += term1.sub_(term2)
368
+ return B
369
+
175
370
  def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
176
- sy = torch.dot(s, y)
177
- if sy <= tol: return H # don't reset H in this case
178
- num1 = (sy + (y @ H @ y)) * s.outer(s)
179
- term1 = num1.div_(sy**2)
180
- num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
371
+ sy = s.dot(y)
372
+ if sy <= tol: return H
373
+
374
+ sy_sq = safe_clip(sy**2)
375
+
376
+ Hy = H@y
377
+ scale1 = (sy + y.dot(Hy)) / sy_sq
378
+ term1 = s.outer(s).mul_(scale1)
379
+
380
+ num2 = (Hy.outer(s)).add_(s.outer(y @ H))
181
381
  term2 = num2.div_(sy)
382
+
182
383
  H += term1.sub_(term2)
183
384
  return H
184
385
 
185
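For reference, ``bfgs_B_`` and ``bfgs_H_`` above are the standard BFGS updates of the Hessian approximation B and of its inverse H, written out here for readability (not part of the diff):

```latex
B_{k+1} = B_k + \frac{y y^\top}{y^\top s} - \frac{B_k s s^\top B_k}{s^\top B_k s}
H_{k+1} = H_k + \frac{\bigl(s^\top y + y^\top H_k y\bigr)\, s s^\top}{(s^\top y)^2} - \frac{H_k y s^\top + s y^\top H_k}{s^\top y}
```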
- class BFGS(HUpdateStrategy):
186
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
187
- return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])
386
+ class BFGS(_InverseHessianUpdateStrategyDefaults):
387
+ """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
388
+
389
+ Note:
390
+ a line search or a trust region is recommended
391
+
392
+ Warning:
393
+ this uses at least O(N^2) memory.
394
+
395
+ Args:
396
+ init_scale (float | Literal["auto"], optional):
397
+ initial hessian matrix is set to identity times this.
398
+
399
+ "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
400
+
401
+ Defaults to "auto".
402
+ tol (float, optional):
403
+ tolerance on curvature condition. Defaults to 1e-32.
404
+ ptol (float | None, optional):
405
+ skips the update if the maximum difference between current and previous parameters is less than this, to avoid instability.
406
+ Defaults to 1e-32.
407
+ ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
408
+ restart_interval (int | None | Literal["auto"], optional):
409
+ interval between resetting the hessian approximation.
410
+
411
+ "auto" corresponds to number of decision variables + 1.
412
+
413
+ None - no resets.
414
+
415
+ Defaults to None.
416
+ beta (float | None, optional): momentum on H or B. Defaults to None.
417
+ update_freq (int, optional): frequency of updating H or B. Defaults to 1.
418
+ scale_first (bool, optional):
419
+ whether to downscale the first step, before the Hessian approximation becomes available. Defaults to False.
420
+ scale_second (bool, optional): whether to downscale second step. Defaults to False.
421
+ concat_params (bool, optional):
422
+ If true, all parameters are treated as a single vector.
423
+ If False, the update rule is applied to each parameter separately. Defaults to True.
424
+ inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
425
+
426
+ ## Examples:
427
+
428
+ BFGS with backtracking line search:
429
+
430
+ ```python
431
+ opt = tz.Modular(
432
+ model.parameters(),
433
+ tz.m.BFGS(),
434
+ tz.m.Backtracking()
435
+ )
436
+ ```
437
+
438
+ BFGS with a trust region:
439
+ ```python
440
+ opt = tz.Modular(
441
+ model.parameters(),
442
+ tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
443
+ )
444
+ ```
445
+ """
446
+
447
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
448
+ return bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
449
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
450
+ return bfgs_B_(B=B, s=s, y=y, tol=setting['tol'])
188
451
 
189
452
  # ------------------------------------ SR1 ----------------------------------- #
190
- def sr1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
453
+ def sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
191
454
  z = s - H@y
192
- denom = torch.dot(z, y)
455
+ denom = z.dot(y)
193
456
 
194
457
  z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
195
458
  y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
196
459
 
197
- if y_norm*z_norm < tol: return H
460
+ # if y_norm*z_norm < tol: return H
198
461
 
199
462
  # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
200
463
  if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
201
- H += torch.outer(z, z).div_(denom)
464
+ H += z.outer(z).div_(safe_clip(denom))
202
465
  return H
203
466
 
204
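For reference, ``sr1_`` above is the standard symmetric rank-one update, skipped when the denominator is small relative to the norms of ``y`` and ``s - Hy`` as in Nocedal & Wright p.146; the same formula with ``s`` and ``y`` swapped updates B, which is how the ``SR1`` class below reuses it (written out for readability, not part of the diff):

```latex
H_{k+1} = H_k + \frac{(s - H_k y)(s - H_k y)^\top}{(s - H_k y)^\top y}
```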
- class SR1(HUpdateStrategy):
205
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
206
- return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])
467
+ class SR1(_InverseHessianUpdateStrategyDefaults):
468
+ """Symmetric Rank 1. This works best with a trust region:
469
+ ```python
470
+ tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False))
471
+ ```
472
+
473
+ Args:
474
+ init_scale (float | Literal["auto"], optional):
475
+ initial hessian matrix is set to identity times this.
476
+
477
+ "auto" corresponds to a heuristic from [1] p.142-143.
478
+
479
+ Defaults to "auto".
480
+ tol (float, optional):
481
+ tolerance for denominator in SR1 update rule as in [1] p.146. Defaults to 1e-32.
482
+ ptol (float | None, optional):
483
+ skips the update if the maximum difference between current and previous parameters is less than this, to avoid instability.
484
+ Defaults to 1e-32.
485
+ ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
486
+ restart_interval (int | None | Literal["auto"], optional):
487
+ interval between resetting the hessian approximation.
488
+
489
+ "auto" corresponds to number of decision variables + 1.
490
+
491
+ None - no resets.
492
+
493
+ Defaults to None.
494
+ beta (float | None, optional): momentum on H or B. Defaults to None.
495
+ update_freq (int, optional): frequency of updating H or B. Defaults to 1.
496
+ scale_first (bool, optional):
497
+ whether to downscale the first step, before the Hessian approximation becomes available. Defaults to False.
498
+ scale_second (bool, optional): whether to downscale second step. Defaults to False.
499
+ concat_params (bool, optional):
500
+ If true, all parameters are treated as a single vector.
501
+ If False, the update rule is applied to each parameter separately. Defaults to True.
502
+ inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
503
+
504
+ ### Examples:
505
+
506
+ SR1 with a trust region:
507
+ ```python
508
+ opt = tz.Modular(
509
+ model.parameters(),
510
+ tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
511
+ )
512
+ ```
513
+
514
+ ### References:
515
+ [1] Nocedal, J., & Wright, S. J. Numerical Optimization.
516
+ """
517
+
518
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
519
+ return sr1_(H=H, s=s, y=y, tol=setting['tol'])
520
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
521
+ return sr1_(H=B, s=y, y=s, tol=setting['tol'])
522
+
207
523
 
208
524
  # ------------------------------------ DFP ----------------------------------- #
209
525
  def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
210
- sy = torch.dot(s, y)
526
+ sy = s.dot(y)
211
527
  if sy.abs() <= tol: return H
212
- term1 = torch.outer(s, s).div_(sy)
213
- yHy = torch.dot(y, H @ y) #
214
- if yHy.abs() <= tol: return H
215
- num = H @ torch.outer(y, y) @ H
528
+ term1 = s.outer(s).div_(sy)
529
+
530
+ yHy = safe_clip(y.dot(H @ y))
531
+
532
+ num = (H @ y).outer(y) @ H
216
533
  term2 = num.div_(yHy)
534
+
217
535
  H += term1.sub_(term2)
218
536
  return H
219
537
 
220
- class DFP(HUpdateStrategy):
221
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
222
- return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])
538
+ def dfp_B(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
539
+ sy = s.dot(y)
540
+ if sy.abs() <= tol: return B
541
+ I = torch.eye(B.size(0), device=B.device, dtype=B.dtype)
542
+ sub = y.outer(s).div_(sy)
543
+ term1 = I - sub
544
+ term2 = I.sub_(sub.T)
545
+ term3 = y.outer(y).div_(sy)
546
+ B = (term1 @ B @ term2).add_(term3)
547
+ return B
548
+
549
+
550
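For reference, ``dfp_H_`` and ``dfp_B`` above are the standard DFP updates (written out for readability, not part of the diff):

```latex
H_{k+1} = H_k + \frac{s s^\top}{s^\top y} - \frac{H_k y y^\top H_k}{y^\top H_k y}
B_{k+1} = \Bigl(I - \frac{y s^\top}{y^\top s}\Bigr) B_k \Bigl(I - \frac{s y^\top}{y^\top s}\Bigr) + \frac{y y^\top}{y^\top s}
```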
+ class DFP(_InverseHessianUpdateStrategyDefaults):
551
+ """Davidon–Fletcher–Powell Quasi-Newton method.
552
+
553
+ Note:
554
+ a trust region or an accurate line search is recommended.
555
+
556
+ Warning:
557
+ this uses at least O(N^2) memory.
558
+ """
559
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
560
+ return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
561
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
562
+ return dfp_B(B=B, s=s, y=y, tol=setting['tol'])
223
563
 
224
564
 
225
565
  # formulas for methods below from Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
226
566
  # H' = H - (Hy - S)c^T / c^T*y
227
567
  # the difference is how `c` is calculated
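Summarizing the four implementations below: they all share this rank-one form and differ only in the choice of c (added for readability, matching the code):

```latex
H_{k+1} = H_k - \frac{(H_k y - s)\, c^\top}{c^\top y}, \qquad
c = \begin{cases} H_k^\top s & \text{Broyden's good method} \\ y & \text{Broyden's bad method} \\ g_{\text{prev}} & \text{Greenstadt's 1st method} \\ H_k H_k y & \text{Greenstadt's 2nd method} \end{cases}
```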
228
568
 
229
- def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
569
+ def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
230
570
  c = H.T @ s
231
- cy = c.dot(y)
232
- if cy.abs() <= tol: return H
571
+ cy = safe_clip(c.dot(y))
233
572
  num = (H@y).sub_(s).outer(c)
234
573
  H -= num/cy
235
574
  return H
575
+ def broyden_good_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
576
+ r = y - B@s
577
+ ss = safe_clip(s.dot(s))
578
+ B += r.outer(s).div_(ss)
579
+ return B
236
580
 
237
- def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
238
- c = y
239
- cy = c.dot(y)
240
- if cy.abs() <= tol: return H
241
- num = (H@y).sub_(s).outer(c)
242
- H -= num/cy
581
+ def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
582
+ yy = safe_clip(y.dot(y))
583
+ num = (s - (H @ y)).outer(y)
584
+ H += num/yy
243
585
  return H
586
+ def broyden_bad_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
587
+ r = y - B@s
588
+ ys = safe_clip(y.dot(s))
589
+ B += r.outer(y).div_(ys)
590
+ return B
244
591
 
245
- def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor, tol: float):
592
+ def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor):
246
593
  c = g_prev
247
- cy = c.dot(y)
248
- if cy.abs() <= tol: return H
594
+ cy = safe_clip(c.dot(y))
249
595
  num = (H@y).sub_(s).outer(c)
250
596
  H -= num/cy
251
597
  return H
252
598
 
253
- def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
599
+ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
254
600
  Hy = H @ y
255
601
  c = H @ Hy # pylint:disable=not-callable
256
- cy = c.dot(y)
257
- if cy.abs() <= tol: return H
602
+ cy = safe_clip(c.dot(y))
258
603
  num = Hy.sub_(s).outer(c)
259
604
  H -= num/cy
260
605
  return H
261
606
 
262
- class BroydenGood(HUpdateStrategy):
263
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
264
- return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])
607
+ class BroydenGood(_InverseHessianUpdateStrategyDefaults):
608
+ """Broyden's "good" Quasi-Newton method.
609
+
610
+ Note:
611
+ a trust region or an accurate line search is recommended.
612
+
613
+ Warning:
614
+ this uses at least O(N^2) memory.
615
+
616
+ Reference:
617
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
618
+ """
619
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
620
+ return broyden_good_H_(H=H, s=s, y=y)
621
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
622
+ return broyden_good_B_(B=B, s=s, y=y)
265
623
 
266
- class BroydenBad(HUpdateStrategy):
267
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
268
- return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
624
+ class BroydenBad(_InverseHessianUpdateStrategyDefaults):
625
+ """Broyden's "bad" Quasi-Newton method.
269
626
 
270
- class Greenstadt1(HUpdateStrategy):
271
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
272
- return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
627
+ Note:
628
+ a trust region or an accurate line search is recommended.
273
629
 
274
- class Greenstadt2(HUpdateStrategy):
275
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
276
- return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
630
+ Warning:
631
+ this uses at least O(N^2) memory.
277
632
 
633
+ Reference:
634
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
635
+ """
636
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
637
+ return broyden_bad_H_(H=H, s=s, y=y)
638
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
639
+ return broyden_bad_B_(B=B, s=s, y=y)
640
+
641
+ class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
642
+ """Greenstadt's first Quasi-Newton method.
643
+
644
+ Note:
645
+ a trust region or an accurate line search is recommended.
646
+
647
+ Warning:
648
+ this uses at least O(N^2) memory.
649
+
650
+ Reference:
651
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
652
+ """
653
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
654
+ return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev)
278
655
 
279
- def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float):
656
+ class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
657
+ """Greenstadt's second Quasi-Newton method.
658
+
659
+ Note:
660
+ a line search is recommended.
661
+
662
+ Warning:
663
+ this uses at least O(N^2) memory.
664
+
665
+ Reference:
666
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
667
+ """
668
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
669
+ return greenstadt2_H_(H=H, s=s, y=y)
670
+
671
+
672
+ def icum_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
280
673
  j = y.abs().argmax()
281
674
 
282
- denom = y[j]
283
- if denom.abs() < tol: return H
675
+ denom = safe_clip(y[j])
284
676
 
285
677
  Hy = H @ y.unsqueeze(1)
286
678
  num = s.unsqueeze(1) - Hy
@@ -288,161 +680,194 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
288
680
  H[:, j] += num.squeeze() / denom
289
681
  return H
290
682
 
291
- class ColumnUpdatingMethod(HUpdateStrategy):
292
- """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
293
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
294
- return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
683
+ class ICUM(_InverseHessianUpdateStrategyDefaults):
684
+ """
685
+ Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
686
+ because it updates only one column of the inverse Hessian approximation per step.
687
+
688
+ Note:
689
+ a line search is recommended.
690
+
691
+ Warning:
692
+ this uses at least O(N^2) memory.
693
+
694
+ Reference:
695
+ Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
696
+ """
697
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
698
+ return icum_H_(H=H, s=s, y=y)
295
699
 
296
- def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
700
+ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
297
701
  s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
298
702
  I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
299
703
  d = (R + I * (s_norm/2)) @ s
300
- ds = d.dot(s)
301
- if ds.abs() <= tol: return H, R
704
+ ds = safe_clip(d.dot(s))
302
705
  R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))
303
706
 
304
707
  c = H.T @ d
305
- cy = c.dot(y)
306
- if cy.abs() <= tol: return H, R
708
+ cy = safe_clip(c.dot(y))
307
709
  num = (H@y).sub_(s).outer(c)
308
710
  H -= num/cy
309
711
  return H, R
310
712
 
311
- class ThomasOptimalMethod(HUpdateStrategy):
312
- """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
313
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
713
+ class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
714
+ """
715
+ Thomas's "optimal" Quasi-Newton method.
716
+
717
+ Note:
718
+ a line search is recommended.
719
+
720
+ Warning:
721
+ this uses at least O(N^2) memory.
722
+
723
+ Reference:
724
+ Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
725
+ """
726
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
314
727
  if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
315
- H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
728
+ H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
316
729
  return H
317
730
 
318
- def _reset_M_(self, M, s, y,inverse, init_scale, state):
319
- super()._reset_M_(M, s, y, inverse, init_scale, state)
731
+ def reset_P(self, P, s, y, inverse, init_scale, state):
732
+ super().reset_P(P, s, y, inverse, init_scale, state)
320
733
  for st in self.state.values():
321
734
  st.pop("R", None)
322
735
 
323
736
  # ------------------------ powell's symmetric broyden ------------------------ #
324
- def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
737
+ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
325
738
  y_Bs = y - B@s
326
- ss = s.dot(s)
327
- if ss.abs() < tol: return B
739
+ ss = safe_clip(s.dot(s))
328
740
  num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
329
741
  term1 = num1.div_(ss)
330
- term2 = s.outer(s).mul_(y_Bs.dot(s)/(ss**2))
742
+ term2 = s.outer(s).mul_(y_Bs.dot(s)/(safe_clip(ss**2)))
331
743
  B += term1.sub_(term2)
332
744
  return B
333
745
 
334
746
  # I couldn't find formula for H
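For reference, ``psb_B_`` above is Powell's symmetric Broyden update of B (as the comment notes, only the B form is implemented; written out for readability, not part of the diff):

```latex
B_{k+1} = B_k + \frac{(y - B_k s) s^\top + s (y - B_k s)^\top}{s^\top s} - \frac{\bigl((y - B_k s)^\top s\bigr)\, s s^\top}{(s^\top s)^2}
```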
335
- class PSB(HessianUpdateStrategy):
336
- def __init__(
337
- self,
338
- init_scale: float | Literal["auto"] = 'auto',
339
- tol: float = 1e-10,
340
- tol_reset: bool = True,
341
- reset_interval: int | None = None,
342
- beta: float | None = None,
343
- update_freq: int = 1,
344
- scale_first: bool = True,
345
- scale_second: bool = False,
346
- concat_params: bool = True,
347
- inner: Chainable | None = None,
348
- ):
349
- super().__init__(
350
- defaults=None,
351
- init_scale=init_scale,
352
- tol=tol,
353
- tol_reset=tol_reset,
354
- reset_interval=reset_interval,
355
- beta=beta,
356
- update_freq=update_freq,
357
- scale_first=scale_first,
358
- scale_second=scale_second,
359
- concat_params=concat_params,
360
- inverse=False,
361
- inner=inner,
362
- )
747
+ class PSB(_HessianUpdateStrategyDefaults):
748
+ """Powell's Symmetric Broyden Quasi-Newton method.
363
749
 
364
- def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
365
- return psb_B_(B=B, s=s, y=y, tol=settings['tol'])
750
+ Note:
751
+ a line search or a trust region is recommended.
752
+
753
+ Warning:
754
+ this uses at least O(N^2) memory.
755
+
756
+ Reference:
757
+ Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
758
+ """
759
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
760
+ return psb_B_(B=B, s=s, y=y)
366
761
 
367
762
 
368
763
  # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
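The two updates defined next, ``pearson_H_`` and ``mccormick_H_``, are (written out for readability, matching the code below):

```latex
\text{Pearson:}\quad H_{k+1} = H_k + \frac{(s - H_k y)(H_k y)^\top}{y^\top H_k y}, \qquad
\text{McCormick:}\quad H_{k+1} = H_k + \frac{(s - H_k y)\, s^\top}{s^\top y}
```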
369
- def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
764
+ def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
370
765
  Hy = H@y
371
- yHy = y.dot(Hy)
372
- if yHy.abs() <= tol: return H
766
+ yHy = safe_clip(y.dot(Hy))
373
767
  num = (s - Hy).outer(Hy)
374
768
  H += num.div_(yHy)
375
769
  return H
376
770
 
377
- class Pearson(HUpdateStrategy):
378
- """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
771
+ class Pearson(_InverseHessianUpdateStrategyDefaults):
772
+ """
773
+ Pearson's Quasi-Newton method.
379
774
 
380
- This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
381
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
382
- return pearson_H_(H=H, s=s, y=y, tol=settings['tol'])
775
+ Note:
776
+ a line search is recommended.
383
777
 
384
- def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
385
- sy = s.dot(y)
386
- if sy.abs() <= tol: return H
778
+ Warning:
779
+ this uses at least O(N^2) memory.
780
+
781
+ Reference:
782
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
783
+ """
784
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
785
+ return pearson_H_(H=H, s=s, y=y)
786
+
787
+ def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
788
+ sy = safe_clip(s.dot(y))
387
789
  num = (s - H@y).outer(s)
388
790
  H += num.div_(sy)
389
791
  return H
390
792
 
391
- class McCormick(HUpdateStrategy):
392
- """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
793
+ class McCormick(_InverseHessianUpdateStrategyDefaults):
794
+ """McCormicks's Quasi-Newton method.
393
795
 
394
- This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
395
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
396
- return mccormick_H_(H=H, s=s, y=y, tol=settings['tol'])
796
+ Note:
797
+ a line search is recommended.
397
798
 
398
- def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
799
+ Warning:
800
+ this uses at least O(N^2) memory.
801
+
802
+ Reference:
803
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
804
+
805
+ This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method in other sources.
806
+ """
807
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
808
+ return mccormick_H_(H=H, s=s, y=y)
809
+
810
+ def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
399
811
  Hy = H @ y
400
- yHy = y.dot(Hy)
401
- if yHy.abs() < tol: return H, R
812
+ yHy = safe_clip(y.dot(Hy))
402
813
  H -= Hy.outer(Hy) / yHy
403
814
  R += (s - R@y).outer(Hy) / yHy
404
815
  return H, R
405
816
 
406
817
  class ProjectedNewtonRaphson(HessianUpdateStrategy):
407
- """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
818
+ """
819
+ Projected Newton-Raphson method.
820
+
821
+ Note:
822
+ a line search is recommended.
823
+
824
+ Warning:
825
+ this uses at least O(N^2) memory.
826
+
827
+ Reference:
828
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
408
829
 
409
- Algorithm 7"""
830
+ This one is Algorithm 7.
831
+ """
410
832
  def __init__(
411
833
  self,
412
834
  init_scale: float | Literal["auto"] = 'auto',
413
- tol: float = 1e-10,
414
- tol_reset: bool = True,
415
- reset_interval: int | None | Literal['auto'] = 'auto',
835
+ tol: float = 1e-32,
836
+ ptol: float | None = 1e-32,
837
+ ptol_restart: bool = False,
838
+ gtol: float | None = 1e-32,
839
+ restart_interval: int | None | Literal['auto'] = 'auto',
416
840
  beta: float | None = None,
417
841
  update_freq: int = 1,
418
- scale_first: bool = True,
419
- scale_second: bool = False,
842
+ scale_first: bool = False,
420
843
  concat_params: bool = True,
421
844
  inner: Chainable | None = None,
422
845
  ):
423
846
  super().__init__(
424
847
  init_scale=init_scale,
425
848
  tol=tol,
426
- tol_reset=tol_reset,
427
- reset_interval=reset_interval,
849
+ ptol = ptol,
850
+ ptol_restart=ptol_restart,
851
+ gtol=gtol,
852
+ restart_interval=restart_interval,
428
853
  beta=beta,
429
854
  update_freq=update_freq,
430
855
  scale_first=scale_first,
431
- scale_second=scale_second,
432
856
  concat_params=concat_params,
433
857
  inverse=True,
434
858
  inner=inner,
435
859
  )
436
860
 
437
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
861
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
438
862
  if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
439
- H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
863
+ H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y)
440
864
  state["R"] = R
441
865
  return H
442
866
 
443
- def _reset_M_(self, M, s, y, inverse, init_scale, state):
867
+ def reset_P(self, P, s, y, inverse, init_scale, state):
444
868
  assert inverse
445
- M.copy_(state["R"])
869
+ if 'R' not in state: state['R'] = torch.eye(P.size(-1), device=P.device, dtype=P.dtype)
870
+ P.copy_(state["R"])
446
871
 
447
872
  # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
448
873
  def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
@@ -454,12 +879,10 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
454
879
  # however p.12 says eps = gs / gHy
455
880
 
456
881
  Hy = H@y
457
- gHy = g.dot(Hy)
458
- yHy = y.dot(Hy)
882
+ gHy = safe_clip(g.dot(Hy))
883
+ yHy = safe_clip(y.dot(Hy))
459
884
  sy = s.dot(y)
460
- if sy < tol: return H
461
- if yHy.abs() < tol: return H
462
- if gHy.abs() < tol: return H
885
+ if sy < tol: return H # the proof assumes sy > 0; it is unclear whether the update should be skipped in this case
463
886
 
464
887
  v_mul = yHy.sqrt()
465
888
  v_term1 = s/sy
@@ -474,28 +897,26 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
  e = gs / gHy
  if switch in (1, 3):
  if e/o <= 1:
- if o.abs() <= tol: return H
- phi = e/o
+ phi = e/safe_clip(o)
  theta = 0
  elif o/t >= 1:
- if t.abs() <= tol: return H
- phi = o/t
+ phi = o/safe_clip(t)
  theta = 1
  else:
  phi = 1
- denom = e*t - o**2
- if denom.abs() <= tol: return H
+ denom = safe_clip(e*t - o**2)
  if switch == 1: theta = o * (e - o) / denom
  else: theta = o * (t - o) / denom

  elif switch == 2:
- if t.abs() <= tol or o.abs() <= tol or e.abs() <= tol: return H
+ t = safe_clip(t)
+ o = safe_clip(o)
+ e = safe_clip(e)
  phi = (e / t) ** 0.5
  theta = 1 / (1 + (t*e / o**2)**0.5)

  elif switch == 4:
- if t.abs() <= tol: return H
- phi = e/t
+ phi = e/safe_clip(t)
  theta = 1/2

  else: raise ValueError(switch)
@@ -514,19 +935,30 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi


  class SSVM(HessianUpdateStrategy):
- """This one is from Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
+ """
+ Self-scaling variable metric Quasi-Newton method.
+
+ Note:
+ a line search is recommended.
+
+ Warning:
+ this uses at least O(N^2) memory.
+
+ Reference:
+ Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
  """
  def __init__(
  self,
  switch: tuple[float,float] | Literal[1,2,3,4] = 3,
  init_scale: float | Literal["auto"] = 'auto',
- tol: float = 1e-10,
- tol_reset: bool = True,
- reset_interval: int | None = None,
+ tol: float = 1e-32,
+ ptol: float | None = 1e-32,
+ ptol_restart: bool = False,
+ gtol: float | None = 1e-32,
+ restart_interval: int | None = None,
  beta: float | None = None,
  update_freq: int = 1,
- scale_first: bool = True,
- scale_second: bool = False,
+ scale_first: bool = False,
  concat_params: bool = True,
  inner: Chainable | None = None,
  ):
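For orientation, a minimal usage sketch of the `SSVM` constructor shown above, paired with a backtracking line search as its new docstring recommends; `tz.Modular` and `tz.m.Backtracking` follow the `GradientCorrection` example that appears later in this file, and everything else here is illustrative:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.SSVM(switch=3),   # switch=3 is the default shown above
    tz.m.Backtracking(),   # the docstring recommends a line search
)
```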
@@ -535,28 +967,28 @@ class SSVM(HessianUpdateStrategy):
  defaults=defaults,
  init_scale=init_scale,
  tol=tol,
- tol_reset=tol_reset,
- reset_interval=reset_interval,
+ ptol=ptol,
+ ptol_restart=ptol_restart,
+ gtol=gtol,
+ restart_interval=restart_interval,
  beta=beta,
  update_freq=update_freq,
  scale_first=scale_first,
- scale_second=scale_second,
  concat_params=concat_params,
  inverse=True,
  inner=inner,
  )

- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return ssvm_H_(H=H, s=s, y=y, g=g, switch=setting['switch'], tol=setting['tol'])

  # HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
  def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  Hy = H@y
  ys = y.dot(s)
- if ys.abs() <= tol: return H
+ if ys.abs() <= tol: return H # probably safe to skip, since this update is BFGS- and DFP-like
  yHy = y.dot(Hy)
- denom = ys + yHy
- if denom.abs() <= tol: return H
+ denom = safe_clip(ys + yHy)

  term1 = 1/denom
  term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
@@ -569,19 +1001,35 @@ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  return H

  def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
- sy = s.dot(y)
- if sy.abs() < torch.finfo(g[0].dtype).eps: return g
+ sy = safe_clip(s.dot(y))
  return g - (y * (s.dot(g) / sy))


  class GradientCorrection(Transform):
- """estimates gradient at minima along search direction assuming function is quadratic as proposed in HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+ """
+ Estimates the gradient at the minimum along the search direction, assuming the function is quadratic.
+
+ This can be useful as an inner module for second-order methods with an inexact line search.
+
+ ## Example:
+ L-BFGS with gradient correction

- This can useful as inner module for second order methods."""
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LBFGS(inner=tz.m.GradientCorrection()),
+ tz.m.Backtracking()
+ )
+ ```
+
+ Reference:
+ HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+
+ """
  def __init__(self):
  super().__init__(None, uses_grad=False)

- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  if 'p_prev' not in states[0]:
  p_prev = unpack_states(states, tensors, 'p_prev', init=params)
  g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
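The `gradient_correction` helper above computes `g - y * (s·g / s·y)`. As a standalone illustration (plain tensors rather than `TensorList`, and not part of the package), on an exact quadratic this reproduces the gradient at the minimizer along the previous step direction:

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3)
A = A @ A.T + torch.eye(3)             # SPD Hessian of f(x) = 0.5 * x^T A x
x_prev, x = torch.randn(3), torch.randn(3)

g_prev, g = A @ x_prev, A @ x          # exact gradients
s, y = x - x_prev, g - g_prev          # step and gradient difference (y = A s)

g_hat = g - y * (s.dot(g) / s.dot(y))  # the correction formula

# g_hat equals the gradient at the minimizer of f along the line x + t*s ...
t_star = -s.dot(g) / s.dot(y)
print(torch.allclose(g_hat, A @ (x + t_star * s)))                 # True
# ... and is therefore orthogonal to the previous step.
print(torch.isclose(s.dot(g_hat), torch.tensor(0.0), atol=1e-5))   # True
```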
@@ -594,15 +1042,27 @@ class GradientCorrection(Transform):
  g_prev.copy_(tensors)
  return g_hat

- class Horisho(HUpdateStrategy):
- """HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394"""
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return hoshino_H_(H=H, s=s, y=y, tol=settings['tol'])
+ class Horisho(_InverseHessianUpdateStrategyDefaults):
+ """
+ Hoshino's variable metric Quasi-Newton method.
+
+ Note:
+ a line search is recommended.
+
+ Warning:
+ this uses at least O(N^2) memory.
+
+ Reference:
+ HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+ """
+
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return hoshino_H_(H=H, s=s, y=y, tol=setting['tol'])

  # Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
  def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  sy = s.dot(y)
- if sy.abs() < tol: return H
+ if sy.abs() < tol: return H # part of algorithm
  Hy = H @ y

  term1 = (s.outer(y) @ H).div_(sy)
@@ -613,16 +1073,27 @@ def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float)
  H -= (term1 + term2 - term4.mul_(term3))
  return H

- class FletcherVMM(HUpdateStrategy):
- """Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317"""
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return fletcher_vmm_H_(H=H, s=s, y=y, tol=settings['tol'])
+ class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
+ """
+ Fletcher's variable metric Quasi-Newton method.
+
+ Note:
+ a line search is recommended.
+
+ Warning:
+ this uses at least O(N^2) memory.
+
+ Reference:
+ Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+ """
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return fletcher_vmm_H_(H=H, s=s, y=y, tol=setting['tol'])


  # Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
  def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
  sy = s.dot(y)
- if sy < tol: return H
+ if sy < tol: return H # part of algorithm

  term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy

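`Horisho` and `FletcherVMM` now inherit their constructor defaults from `_InverseHessianUpdateStrategyDefaults`, so either can be dropped into the same place in a chain. A sketch along the same lines as the SSVM example above (it assumes both classes are exported under `tz.m`, which this hunk does not show):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.FletcherVMM(),    # or tz.m.Horisho()
    tz.m.Backtracking(),   # both docstrings recommend a line search
)
```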
@@ -644,20 +1115,29 @@ def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol:


  class NewSSM(HessianUpdateStrategy):
- """Self-scaling method, requires a line search.
+ """Self-scaling Quasi-Newton method.
+
+ Note:
+ a line search such as ``tz.m.StrongWolfe()`` is required.
+
+ Warning:
+ this uses roughly O(N^2) memory.

- Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U."""
+ Reference:
+ Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+ """
  def __init__(
  self,
  type: Literal[1, 2] = 1,
  init_scale: float | Literal["auto"] = "auto",
- tol: float = 1e-10,
- tol_reset: bool = True,
- reset_interval: int | None = None,
+ tol: float = 1e-32,
+ ptol: float | None = 1e-32,
+ ptol_restart: bool = False,
+ gtol: float | None = 1e-32,
+ restart_interval: int | None = None,
  beta: float | None = None,
  update_freq: int = 1,
- scale_first: bool = True,
- scale_second: bool = False,
+ scale_first: bool = False,
  concat_params: bool = True,
  inner: Chainable | None = None,
  ):
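A usage sketch for `NewSSM`; unlike the classes above, its docstring requires a line search and names ``tz.m.StrongWolfe()`` explicitly (`type=1` is the default shown in the signature), while the rest of this snippet is illustrative:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.NewSSM(type=1),
    tz.m.StrongWolfe(),   # required per the docstring
)
```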
@@ -665,19 +1145,87 @@ class NewSSM(HessianUpdateStrategy):
  defaults=dict(type=type),
  init_scale=init_scale,
  tol=tol,
- tol_reset=tol_reset,
- reset_interval=reset_interval,
+ ptol=ptol,
+ ptol_restart=ptol_restart,
+ gtol=gtol,
+ restart_interval=restart_interval,
  beta=beta,
  update_freq=update_freq,
  scale_first=scale_first,
- scale_second=scale_second,
  concat_params=concat_params,
  inverse=True,
  inner=inner,
  )
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
  f = state['f']
  f_prev = state['f_prev']
- return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=settings['type'], tol=settings['tol'])
+ return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=setting['type'], tol=setting['tol'])
+
+ # ---------------------------- Shor’s r-algorithm ---------------------------- #
+ # def shor_r(B:torch.Tensor, y:torch.Tensor, gamma:float):
+ # r = B.T @ y
+ # r /= torch.linalg.vector_norm(r).clip(min=1e-32) # pylint:disable=not-callable
+
+ # I = torch.eye(B.size(1), device=B.device, dtype=B.dtype)
+ # return B @ (I - gamma*r.outer(r))
+
+ # this is supposed to be equivalent (and it is)
+ def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
+ p = H@y
+ # (1 - alpha^2) * (p p^T) / (p^T y)
+ #term = p.outer(p).div_(p.dot(y).clip(min=1e-32))
+ term = p.outer(p).div_(safe_clip(p.dot(y)))
+ H.sub_(term, alpha=1-alpha**2)
+ return H
+
+ class ShorR(HessianUpdateStrategy):
+ """Shor’s r-algorithm.
+
+ Note:
+ A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
+ Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling,
+ so setting ``a_init`` in the line search is recommended.

+ References:
+ Shor, N. Z. (1985). Minimization Methods for Non-differentiable Functions. New York: Springer.
+
+ Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720. - good overview.
+
+ Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. - this is where a more efficient formula is described.
+ """
+
+ def __init__(
+ self,
+ alpha=0.5,
+ init_scale: float | Literal["auto"] = 1,
+ tol: float = 1e-32,
+ ptol: float | None = 1e-32,
+ ptol_restart: bool = False,
+ gtol: float | None = 1e-32,
+ restart_interval: int | None | Literal['auto'] = None,
+ beta: float | None = None,
+ update_freq: int = 1,
+ scale_first: bool = False,
+ concat_params: bool = True,
+ # inverse: bool = True,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(alpha=alpha)
+ super().__init__(
+ defaults=defaults,
+ init_scale=init_scale,
+ tol=tol,
+ ptol=ptol,
+ ptol_restart=ptol_restart,
+ gtol=gtol,
+ restart_interval=restart_interval,
+ beta=beta,
+ update_freq=update_freq,
+ scale_first=scale_first,
+ concat_params=concat_params,
+ inverse=True,
+ inner=inner,
+ )

+
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return shor_r_(H=H, y=y, alpha=setting['alpha'])
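Finally, a usage sketch for `ShorR` with the line search configuration its docstring asks for; as with the other sketches, only module names that appear in this diff are used, and the rest is illustrative:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.ShorR(alpha=0.5),  # alpha=0.5 is the default shown above
    tz.m.StrongWolfe(a_init="quadratic", fallback=True),
)
```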