torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,37 +1,27 @@
-"""Use BFGS or maybe SR1."""
+import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Mapping, Callable
+from collections.abc import Callable, Mapping
 from typing import Any, Literal
-import warnings

 import torch

 from ...core import Chainable, Module, TensorwiseTransform, Transform
-from ...utils import TensorList, set_storage_, unpack_states
-from ..functional import safe_scaling_
+from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
+from ...utils.linalg import linear_operator
+from ..functional import initial_step_size, safe_clip


-def _safe_dict_update_(d1_:dict, d2:dict):
-    inter = set(d1_.keys()).intersection(d2.keys())
-    if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
-    d1_.update(d2)

 def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     if (beta is None) or (beta == 0) or (key not in state): state[key] = value
     elif state[key].shape != value.shape: state[key] = value
     else: state[key].lerp_(value, 1-beta)

-def _safe_clip(x: torch.Tensor):
-    """makes sure scalar tensor x is not smaller than epsilon"""
-    assert x.numel() == 1, x.shape
-    eps = torch.finfo(x.dtype).eps ** 2
-    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
-    return x
-
 class HessianUpdateStrategy(TensorwiseTransform, ABC):
     """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.

-    This is an abstract class, to use it, subclass it and override `update_H` and/or `update_B`.
+    This is an abstract class, to use it, subclass it and override ``update_H`` and/or ``update_B``,
+    and if necessary, ``initialize_P``, ``modify_H`` and ``modify_B``.

     Args:
         defaults (dict | None, optional): defaults. Defaults to None.
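Note on the deleted helpers: `_safe_dict_update_` and `_safe_clip` were not dropped but promoted to shared utilities — the new imports pull in `safe_dict_update_` from `...utils` and `safe_clip` from `..functional`. A minimal sketch of the clipping behavior, taken verbatim from the removed private implementation (the relocated version may differ in detail):

```python
import torch

def safe_clip(x: torch.Tensor) -> torch.Tensor:
    # keeps a scalar tensor away from zero so later divisions stay finite;
    # mirrors the removed _safe_clip: magnitudes below eps**2 are replaced
    # by eps**2 carrying the original sign
    assert x.numel() == 1, x.shape
    eps = torch.finfo(x.dtype).eps ** 2
    if x.abs() < eps:
        return x.new_full(x.size(), eps).copysign(x)
    return x

# a curvature denominator that underflowed to 0.0 is nudged to +eps**2
print(safe_clip(torch.tensor(0.0)))
```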
@@ -42,13 +32,13 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):

             Defaults to "auto".
         tol (float, optional):
-            algorithm-dependent tolerance (usually on curvature condition). Defaults to 1e-8.
+            algorithm-dependent tolerance (usually on curvature condition). Defaults to 1e-32.
         ptol (float | None, optional):
-            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
-        ptol_reset (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
         gtol (float | None, optional):
-            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-10.
-        reset_interval (int | None | Literal["auto"], optional):
+            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-32.
+        restart_interval (int | None | Literal["auto"], optional):
             interval between resetting the hessian approximation.

             "auto" corresponds to number of decision variables + 1.
@@ -70,141 +60,101 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             Defaults to True.
         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

-    Example:
-        Implementing BFGS method that maintains an estimate of the hessian inverse (H):
-
-        .. code-block:: python
-
-            class BFGS(HessianUpdateStrategy):
-                def __init__(
-                    self,
-                    init_scale: float | Literal["auto"] = "auto",
-                    tol: float = 1e-8,
-                    ptol: float = 1e-10,
-                    ptol_reset: bool = False,
-                    reset_interval: int | None = None,
-                    beta: float | None = None,
-                    update_freq: int = 1,
-                    scale_first: bool = True,
-                    scale_second: bool = False,
-                    concat_params: bool = True,
-                    inner: Chainable | None = None,
-                ):
-                    super().__init__(
-                        defaults=None,
-                        init_scale=init_scale,
-                        tol=tol,
-                        ptol=ptol,
-                        ptol_reset=ptol_reset,
-                        reset_interval=reset_interval,
-                        beta=beta,
-                        update_freq=update_freq,
-                        scale_first=scale_first,
-                        scale_second=scale_second,
-                        concat_params=concat_params,
-                        inverse=True,
-                        inner=inner,
-                    )
-
-                def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-                    tol = settings["tol"]
-                    sy = torch.dot(s, y)
-                    if sy <= tol: return H
-                    num1 = (sy + (y @ H @ y)) * s.outer(s)
-                    term1 = num1.div_(sy**2)
-                    num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
-                    term2 = num2.div_(sy)
-                    H += term1.sub_(term2)
-                    return H
+    ## Notes
+
+    ### update
+
+    On 1st ``update_tensor`` H or B is initialized using ``initialize_P``, which returns identity matrix by default.
+
+    2nd and subsequent ``update_tensor`` calls ``update_H`` or ``update_B``.
+
+    Whether ``H`` or ``B`` is used depends on value of ``inverse`` setting.
+
+    ### apply
+
+    ``apply_tensor`` computes ``H = modify_H(H)`` or ``B = modify_B(B)``, those methods do nothing by default.
+
+    Then it computes and returns ``H @ input`` or ``solve(B, input)``.
+
+    Whether ``H`` or ``B`` is used depends on value of ``inverse`` setting.
+
+    ### initial scale
+
+    If ``init_scale`` is a scalar, the preconditioner is multiplied or divided (if inverse) by it on first ``update_tensor``.

+    If ``init_scale="auto"``, it is computed and applied on the second ``update_tensor``.
+
+    ### get_H
+
+    First it computes ``H = modify_H(H)`` or ``B = modify_B(B)``.
+
+    Returns a ``Dense`` linear operator with ``B``, or ``DenseInverse`` linear operator with ``H``.
+
+    But if H/B has 1 dimension, ``Diagonal`` linear operator is returned with ``B`` or ``1/H``.
     """
     def __init__(
         self,
         defaults: dict | None = None,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-8,
-        ptol: float | None = 1e-10,
-        ptol_reset: bool = False,
-        gtol: float | None = 1e-10,
-        reset_interval: int | None | Literal['auto'] = None,
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = None,
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool = True,
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inverse: bool = True,
         inner: Chainable | None = None,
     ):
         if defaults is None: defaults = {}
-        _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_reset=ptol_reset, gtol=gtol, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, scale_first=scale_first, inner=inner)
-
-    def _init_M(self, size:int, device, dtype, is_inverse:bool):
-        return torch.eye(size, device=device, dtype=dtype)
+        safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_restart=ptol_restart, gtol=gtol, inverse=inverse, beta=beta, restart_interval=restart_interval, scale_first=scale_first))
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)

-    def _get_init_scale(self,s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
-        """returns multiplier to H or B"""
-        ys = y.dot(s)
-        yy = y.dot(y)
-        if ys != 0 and yy != 0: return yy/ys
-        return 1
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('f_prev', 'p_prev', 'g_prev')

-    def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
-        set_storage_(M, self._init_M(s.numel(), device=M.device, dtype=M.dtype, is_inverse=inverse))
-        if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
-        if init_scale >= 1:
-            if inverse: M /= init_scale
-            else: M *= init_scale
+    # ---------------------------- methods to override --------------------------- #
+    def initialize_P(self, size:int, device, dtype, is_inverse:bool) -> torch.Tensor:
+        """returns the initial torch.Tensor for H or B"""
+        return torch.eye(size, device=device, dtype=dtype)

     def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
                  p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
         """update hessian inverse"""
-        raise NotImplementedError
+        raise NotImplementedError(f"hessian inverse approximation is not implemented for {self.__class__.__name__}.")

     def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
                  p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
         """update hessian"""
-        raise NotImplementedError
-
-    def reset_for_online(self):
-        super().reset_for_online()
-        self.clear_state_keys('f_prev', 'p_prev', 'g_prev')
+        raise NotImplementedError(f"{self.__class__.__name__} only supports hessian inverse approximation. "
+                                  "Remove the `inverse=False` argument when initializing this module.")

-    def get_B(self) -> tuple[torch.Tensor, bool]:
-        """returns (B or H, is_inverse)."""
-        state = next(iter(self.state.values()))
-        if "B" in state: return state["B"], False
-        return state["H"], True
+    def modify_B(self, B: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies B out of place before appling the update rule, doesn't affect the buffer B."""
+        return B

-    def get_H(self) -> tuple[torch.Tensor, bool]:
-        """returns (H or B, is_inverse)."""
-        state = next(iter(self.state.values()))
-        if "H" in state: return state["H"], False
-        return state["B"], True
-
-    def make_Bv(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        B, is_inverse = self.get_B()
-
-        if is_inverse:
-            H=B
-            warnings.warn(f'{self} maintains H, so Bv will be inefficient!')
-            def Hxv(v): return torch.linalg.solve_ex(H, v)[0] # pylint:disable=not-callable
-            return Hxv
-
-        def Bv(v): return B@v
-        return Bv
-
-    def make_Hv(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        H, is_inverse = self.get_H()
+    def modify_H(self, H: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies H out of place before appling the update rule, doesn't affect the buffer H."""
+        return H

-        if is_inverse:
-            B=H
-            warnings.warn(f'{self} maintains B, so Hv will be inefficient!')
-            def Bxv(v): return torch.linalg.solve_ex(B, v)[0] # pylint:disable=not-callable
-            return Bxv
+    # ------------------------------ common methods ------------------------------ #
+    def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
+        """returns multiplier to B on 2nd step if ``init_scale='auto'``. H should be divided by this!"""
+        ys = y.dot(s)
+        yy = y.dot(y)
+        if ys != 0 and yy != 0: return yy/ys
+        return 1

-        def Hv(v): return H@v
-        return Hv
+    def reset_P(self, P: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
+        """resets ``P`` which is either B or H"""
+        set_storage_(P, self.initialize_P(s.numel(), device=P.device, dtype=P.dtype, is_inverse=inverse))
+        if init_scale == 'auto': init_scale = self.auto_initial_scale(s,y)
+        if init_scale >= 1:
+            if inverse: P /= init_scale
+            else: P *= init_scale

     @torch.no_grad
     def update_tensor(self, tensor, param, grad, loss, state, setting):
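The hunk above defines the full override surface (`initialize_P`, `update_H`/`update_B`, `modify_H`/`modify_B`). A hypothetical minimal subclass, to show how the pieces fit together — the name `DiagonalSecant` and its update rule are invented for this sketch, not part of the package:

```python
class DiagonalSecant(_HessianUpdateStrategyDefaults):
    """Hypothetical diagonal quasi-Newton method: B is stored as a 1-D
    tensor, which apply_tensor divides by elementwise (the B.ndim == 1
    branch shown further down in this diff)."""

    def initialize_P(self, size, device, dtype, is_inverse):
        # identity matrix, stored as its diagonal
        return torch.ones(size, device=device, dtype=dtype)

    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        # per-coordinate secant condition B_ii * s_i = y_i, applied only
        # where the step is large enough; clamped positive so the
        # elementwise division stays well-defined
        mask = s.abs() > setting["tol"]
        B[mask] = (y[mask] / s[mask]).abs().clamp(min=setting["tol"])
        return B
```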
@@ -216,14 +166,14 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
         state['step'] = step
         init_scale = setting['init_scale']
         ptol = setting['ptol']
-        ptol_reset = setting['ptol_reset']
+        ptol_restart = setting['ptol_restart']
         gtol = setting['gtol']
-        reset_interval = setting['reset_interval']
-        if reset_interval == 'auto': reset_interval = tensor.numel() + 1
+        restart_interval = setting['restart_interval']
+        if restart_interval == 'auto': restart_interval = tensor.numel() + 1

         if M is None or 'f_prev' not in state:
             if M is None: # won't be true on reset_for_online
-                M = self._init_M(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
+                M = self.initialize_P(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
                 if isinstance(init_scale, (int, float)) and init_scale != 1:
                     if inverse: M /= init_scale
                     else: M *= init_scale
@@ -242,13 +192,13 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             state['p_prev'].copy_(p)
             state['g_prev'].copy_(g)

-        if reset_interval is not None and step % reset_interval == 0:
-            self._reset_M_(M, s, y, inverse, init_scale, state)
+        if restart_interval is not None and step % restart_interval == 0:
+            self.reset_P(M, s, y, inverse, init_scale, state)
             return

         # tolerance on parameter difference to avoid exploding after converging
         if ptol is not None and s.abs().max() <= ptol:
-            if ptol_reset: self._reset_M_(M, s, y, inverse, init_scale, state) # reset history
+            if ptol_restart: self.reset_P(M, s, y, inverse, init_scale, state) # reset history
             return

         # tolerance on gradient difference to avoid exploding when there is no curvature
@@ -256,8 +206,8 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             return

         if step == 2 and init_scale == 'auto':
-            if inverse: M /= self._get_init_scale(s,y)
-            else: M *= self._get_init_scale(s,y)
+            if inverse: M /= self.auto_initial_scale(s,y)
+            else: M *= self.auto_initial_scale(s,y)

         beta = setting['beta']
         if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
@@ -272,72 +222,86 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):

         state['f_prev'] = loss

-    def _post_B(self, B: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
-        """modifies B before appling the update rule. Must return (B, g)"""
-        return B, g
-
-    def _post_H(self, H: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
-        """modifies H before appling the update rule. Must return (H, g)"""
-        return H, g
-
     @torch.no_grad
     def apply_tensor(self, tensor, param, grad, loss, state, setting):
-        step = state.get('step', 0)
+        step = state['step']

-        if setting['scale_second'] and step == 2:
-            tensor = safe_scaling_(tensor)
+        if setting['scale_first'] and step == 1:
+            tensor *= initial_step_size(tensor)

         inverse = setting['inverse']
+        g = tensor.view(-1)
+
         if inverse:
             H = state['H']
-            H, g = self._post_H(H, tensor.view(-1), state, setting)
+            H = self.modify_H(H, state, setting)
             if H.ndim == 1: return g.mul_(H).view_as(tensor)
             return (H @ g).view_as(tensor)

         B = state['B']
-        H, g = self._post_B(B, tensor.view(-1), state, setting)
+        B = self.modify_B(B, state, setting)

         if B.ndim == 1: return g.div_(B).view_as(tensor)
         x, info = torch.linalg.solve_ex(B, g) # pylint:disable=not-callable
         if info == 0: return x.view_as(tensor)
-        return safe_scaling_(tensor)
+
+        # failed to solve linear system, so reset state
+        self.state.clear()
+        self.global_state.clear()
+        return tensor.mul_(initial_step_size(tensor))
+
+    def get_H(self, var):
+        param = var.params[0]
+        state = self.state[param]
+        settings = self.settings[param]
+        if "B" in state:
+            B = self.modify_B(state["B"], state, settings)
+            if B.ndim == 2: return linear_operator.Dense(B)
+            assert B.ndim == 1, B.shape
+            return linear_operator.Diagonal(B)
+
+        if "H" in state:
+            H = self.modify_H(state["H"], state, settings)
+            if H.ndim != 1: return linear_operator.DenseInverse(H)
+            return linear_operator.Diagonal(1/H)
+
+        return None

 class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
-    '''This is :code:`HessianUpdateStrategy` subclass for algorithms with no extra defaults, to skip the lengthy __init__.
-    Refer to :code:`HessianUpdateStrategy` documentation.
-
-    Example:
-        Implementing BFGS method that maintains an estimate of the hessian inverse (H):
-
-        .. code-block:: python
-
-            class BFGS(_HessianUpdateStrategyDefaults):
-                """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
-                def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-                    tol = settings["tol"]
-                    sy = torch.dot(s, y)
-                    if sy <= tol: return H
-                    num1 = (sy + (y @ H @ y)) * s.outer(s)
-                    term1 = num1.div_(sy**2)
-                    num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
-                    term2 = num2.div_(sy)
-                    H += term1.sub_(term2)
-                    return H
+    '''This is ``HessianUpdateStrategy`` subclass for algorithms with no extra defaults, to skip the lengthy ``__init__``.
+    Refer to ``HessianUpdateStrategy`` documentation.
+
+    ## Example:
+
+    Implementing BFGS method that maintains an estimate of the hessian inverse (H):
+    ```python
+    class BFGS(_HessianUpdateStrategyDefaults):
+        """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
+        def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+            tol = settings["tol"]
+            sy = torch.dot(s, y)
+            if sy <= tol: return H
+            num1 = (sy + (y @ H @ y)) * s.outer(s)
+            term1 = num1.div_(sy**2)
+            num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+            term2 = num2.div_(sy)
+            H += term1.sub_(term2)
+            return H
+    ```

     Make sure to put at least a basic class level docstring to overwrite this.
     '''
     def __init__(
         self,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-8,
-        ptol: float | None = 1e-10,
-        ptol_reset: bool = False,
-        gtol: float | None = 1e-10,
-        reset_interval: int | None = None,
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool = True,
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inverse: bool = True,
         inner: Chainable | None = None,
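The new `get_H` replaces the removed `get_B`/`make_Bv`/`make_Hv` trio: instead of a raw tensor plus an `is_inverse` flag, callers get one linear operator that always represents the Hessian approximation B. Written out with explicit dense tensors for illustration (the actual `linear_operator` objects presumably apply these maps lazily rather than forming the matrices):

```python
import torch

def materialize_B(state: dict) -> torch.Tensor:
    # dense illustration of what get_H's three branches represent;
    # in every case the operator stands for B, never for its inverse
    if "B" in state:
        B = state["B"]
        return torch.diag(B) if B.ndim == 1 else B   # Diagonal(B) / Dense(B)
    H = state["H"]
    if H.ndim == 1:
        return torch.diag(1.0 / H)                   # Diagonal(1/H)
    return torch.linalg.inv(H)                       # DenseInverse(H): B = H^-1
```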
@@ -347,13 +311,12 @@ class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
             init_scale=init_scale,
             tol=tol,
             ptol=ptol,
-            ptol_reset=ptol_reset,
+            ptol_restart=ptol_restart,
             gtol=gtol,
-            reset_interval=reset_interval,
+            restart_interval=restart_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
-            scale_second=scale_second,
             concat_params=concat_params,
             inverse=inverse,
             inner=inner,
@@ -363,15 +326,14 @@ class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
     def __init__(
         self,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-8,
-        ptol: float | None = 1e-10,
-        ptol_reset: bool = False,
-        gtol: float | None = 1e-10,
-        reset_interval: int | None = None,
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool = True,
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inverse: bool = False,
         inner: Chainable | None = None,
@@ -381,13 +343,12 @@ class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
             init_scale=init_scale,
             tol=tol,
             ptol=ptol,
-            ptol_reset=ptol_reset,
+            ptol_restart=ptol_restart,
             gtol=gtol,
-            reset_interval=reset_interval,
+            restart_interval=restart_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
-            scale_second=scale_second,
             concat_params=concat_params,
             inverse=inverse,
             inner=inner,
@@ -399,7 +360,7 @@ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     if sy < tol: return B

     Bs = B@s
-    sBs = _safe_clip(s.dot(Bs))
+    sBs = safe_clip(s.dot(Bs))

     term1 = y.outer(y).div_(sy)
     term2 = (Bs.outer(s) @ B.T).div_(sBs)
@@ -410,7 +371,7 @@ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     sy = s.dot(y)
     if sy <= tol: return H

-    sy_sq = _safe_clip(sy**2)
+    sy_sq = safe_clip(sy**2)

     Hy = H@y
     scale1 = (sy + y.dot(Hy)) / sy_sq
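For reference, `bfgs_B_` and `bfgs_H_` implement the textbook BFGS pair, with s the parameter difference and y the gradient difference; `safe_clip` only guards the denominators:

```latex
B_{+} = B + \frac{y y^\top}{s^\top y} - \frac{B s s^\top B}{s^\top B s},
\qquad
H_{+} = H + \frac{\left(s^\top y + y^\top H y\right) s s^\top}{(s^\top y)^2}
          - \frac{H y s^\top + s y^\top H}{s^\top y}.
```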
@@ -425,11 +386,11 @@ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
 class BFGS(_InverseHessianUpdateStrategyDefaults):
     """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe()` is recommended, although this can be stable without a line search. Alternatively warmup :code:`tz.m.Warmup` can stabilize quasi-newton methods without line search.
+    Note:
+        a line search or a trust region is recommended

-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Args:
         init_scale (float | Literal["auto"], optional):
@@ -439,12 +400,12 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):

             Defaults to "auto".
         tol (float, optional):
-            tolerance on curvature condition. Defaults to 1e-8.
+            tolerance on curvature condition. Defaults to 1e-32.
         ptol (float | None, optional):
             skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
-            Defaults to 1e-10.
-        ptol_reset (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
-        reset_interval (int | None | Literal["auto"], optional):
+            Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+        restart_interval (int | None | Literal["auto"], optional):
             interval between resetting the hessian approximation.

             "auto" corresponds to number of decision variables + 1.
@@ -462,26 +423,25 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):
             If False, the update rule is applied to each parameter separately. Defaults to True.
         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

-    Examples:
-        BFGS with strong-wolfe line search:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.BFGS(),
-                tz.m.StrongWolfe()
-            )
-
-        BFGS preconditioning applied to momentum:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.BFGS(inner=tz.m.EMA(0.9)),
-                tz.m.LR(1e-2)
-            )
+    ## Examples:
+
+    BFGS with backtracking line search:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.BFGS(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    BFGS with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
+    )
+    ```
     """

     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
@@ -501,38 +461,29 @@ def sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):

     # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
     if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
-    H += z.outer(z).div_(_safe_clip(denom))
+    H += z.outer(z).div_(safe_clip(denom))
     return H

 class SR1(_InverseHessianUpdateStrategyDefaults):
-    """Symmetric Rank 1 Quasi-Newton method.
-
-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
-    .. note::
-        approximate Hessians generated by the SR1 method show faster progress towards the true Hessian than other methods, but it is more unstable. SR1 is best used within a trust region module.
-
-    .. note::
-        SR1 doesn't enforce the hessian estimate to be positive definite, therefore it can generate directions that are not descent directions.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    """Symmetric Rank 1. This works best with a trust region:
+    ```python
+    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False))
+    ```

     Args:
         init_scale (float | Literal["auto"], optional):
             initial hessian matrix is set to identity times this.

-            "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
+            "auto" corresponds to a heuristic from [1] p.142-143.

             Defaults to "auto".
         tol (float, optional):
-            tolerance for denominator in SR1 update rule as in Nocedal, Wright. “Numerical optimization” 2nd p.146. Defaults to 1e-8.
+            tolerance for denominator in SR1 update rule as in [1] p.146. Defaults to 1e-32.
         ptol (float | None, optional):
             skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
-            Defaults to 1e-10.
-        ptol_reset (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
-        reset_interval (int | None | Literal["auto"], optional):
+            Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+        restart_interval (int | None | Literal["auto"], optional):
             interval between resetting the hessian approximation.

             "auto" corresponds to number of decision variables + 1.
@@ -550,26 +501,18 @@ class SR1(_InverseHessianUpdateStrategyDefaults):
             If False, the update rule is applied to each parameter separately. Defaults to True.
         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

-    Examples:
-        SR1 with strong-wolfe line search
+    ### Examples:

-        .. code-block:: python
+    SR1 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+    )
+    ```

-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.SR1(),
-                tz.m.StrongWolfe()
-            )
-
-        BFGS preconditioning applied to momentum
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.SR1(inner=tz.m.EMA(0.9)),
-                tz.m.LR(1e-2)
-            )
+    ### References:
+    [1]. Nocedal. Stephen J. Wright. Numerical Optimization
     """

     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
@@ -584,7 +527,7 @@ def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     if sy.abs() <= tol: return H
     term1 = s.outer(s).div_(sy)

-    yHy = _safe_clip(y.dot(H @ y))
+    yHy = safe_clip(y.dot(H @ y))

     num = (H @ y).outer(y) @ H
     term2 = num.div_(yHy)
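`dfp_H_` is the classic Davidon–Fletcher–Powell inverse update; `safe_clip` guards the curvature denominator:

```latex
H_{+} = H + \frac{s s^\top}{s^\top y} - \frac{H y\, y^\top H}{y^\top H y}.
```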
@@ -607,15 +550,11 @@ def dfp_B(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
 class DFP(_InverseHessianUpdateStrategyDefaults):
     """Davidon–Fletcher–Powell Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Note:
+        a trust region or an accurate line search is recommended.

+    Warning:
+        this uses at least O(N^2) memory.
     """
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
@@ -629,30 +568,30 @@ class DFP(_InverseHessianUpdateStrategyDefaults):

 def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     c = H.T @ s
-    cy = _safe_clip(c.dot(y))
+    cy = safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H
 def broyden_good_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     r = y - B@s
-    ss = _safe_clip(s.dot(s))
+    ss = safe_clip(s.dot(s))
     B += r.outer(s).div_(ss)
     return B

 def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
-    yy = _safe_clip(y.dot(y))
+    yy = safe_clip(y.dot(y))
     num = (s - (H @ y)).outer(y)
     H += num/yy
     return H
 def broyden_bad_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     r = y - B@s
-    ys = _safe_clip(y.dot(s))
+    ys = safe_clip(y.dot(s))
     B += r.outer(y).div_(ys)
     return B

 def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor):
     c = g_prev
-    cy = _safe_clip(c.dot(y))
+    cy = safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H
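These inverse updates share one rank-1 template, H₊ = H + (s − Hy)cᵀ/(cᵀy), differing only in the choice of c: good Broyden uses c = Hᵀs, bad Broyden c = y, and Greenstadt's first method c = g_prev. The direct (B) forms implemented above are:

```latex
\text{good: } B_{+} = B + \frac{(y - B s)\, s^\top}{s^\top s},
\qquad
\text{bad: } B_{+} = B + \frac{(y - B s)\, y^\top}{y^\top s}.
```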
@@ -660,7 +599,7 @@ def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torc
 def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     Hy = H @ y
     c = H @ Hy # pylint:disable=not-callable
-    cy = _safe_clip(c.dot(y))
+    cy = safe_clip(c.dot(y))
     num = Hy.sub_(s).outer(c)
     H -= num/cy
     return H
@@ -668,14 +607,11 @@ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
 class BroydenGood(_InverseHessianUpdateStrategyDefaults):
     """Broyden's "good" Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+    Note:
+        a trust region or an accurate line search is recommended.

-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
@@ -688,14 +624,11 @@ class BroydenGood(_InverseHessianUpdateStrategyDefaults):
 class BroydenBad(_InverseHessianUpdateStrategyDefaults):
     """Broyden's "bad" Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a trust region or an accurate line search is recommended.

-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
@@ -708,14 +641,11 @@ class BroydenBad(_InverseHessianUpdateStrategyDefaults):
 class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
     """Greenstadt's first Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+    Note:
+        a trust region or an accurate line search is recommended.

-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
@@ -726,18 +656,14 @@ class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
 class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
     """Greenstadt's second Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+    Note:
+        a line search is recommended.

-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
-
     """
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         return greenstadt2_H_(H=H, s=s, y=y)
@@ -746,7 +672,7 @@ class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
 def icum_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
     j = y.abs().argmax()

-    denom = _safe_clip(y[j])
+    denom = safe_clip(y[j])

     Hy = H @ y.unsqueeze(1)
     num = s.unsqueeze(1) - Hy
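The hunk is cut off here, but the pieces shown (and the ICUM docstring below, which says only one column is updated per step) correspond to the single-column update

```latex
H_{+} = H + \frac{(s - H y)\, e_j^\top}{y_j}, \qquad j = \arg\max_i |y_i|,
```

which makes the new approximation satisfy the secant equation H₊ y = s exactly while touching only column j.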
@@ -759,11 +685,11 @@ class ICUM(_InverseHessianUpdateStrategyDefaults):
     Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
     due to only updating one column of the inverse hessian approximation per step.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+    Note:
+        a line search is recommended.

-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
@@ -775,11 +701,11 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor)
     s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
     I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
     d = (R + I * (s_norm/2)) @ s
-    ds = _safe_clip(d.dot(s))
+    ds = safe_clip(d.dot(s))
     R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))

     c = H.T @ d
-    cy = _safe_clip(c.dot(y))
+    cy = safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H, R
@@ -788,14 +714,11 @@ class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
     """
     Thomas's "optimal" Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+    Note:
+        a line search is recommended.

-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
@@ -805,18 +728,18 @@ class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
         H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
         return H

-    def _reset_M_(self, M, s, y,inverse, init_scale, state):
-        super()._reset_M_(M, s, y, inverse, init_scale, state)
+    def reset_P(self, P, s, y, inverse, init_scale, state):
+        super().reset_P(P, s, y, inverse, init_scale, state)
         for st in self.state.values():
             st.pop("R", None)

 # ------------------------ powell's symmetric broyden ------------------------ #
 def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     y_Bs = y - B@s
-    ss = _safe_clip(s.dot(s))
+    ss = safe_clip(s.dot(s))
     num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
     term1 = num1.div_(ss)
-    term2 = s.outer(s).mul_(y_Bs.dot(s)/(_safe_clip(ss**2)))
+    term2 = s.outer(s).mul_(y_Bs.dot(s)/(safe_clip(ss**2)))
     B += term1.sub_(term2)
     return B

@@ -824,14 +747,11 @@ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
 class PSB(_HessianUpdateStrategyDefaults):
     """Powell's Symmetric Broyden Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a line search or a trust region is recommended.

-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
@@ -843,7 +763,7 @@ class PSB(_HessianUpdateStrategyDefaults):
 # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
 def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     Hy = H@y
-    yHy = _safe_clip(y.dot(Hy))
+    yHy = safe_clip(y.dot(Hy))
     num = (s - Hy).outer(Hy)
     H += num.div_(yHy)
     return H
@@ -852,14 +772,11 @@ class Pearson(_InverseHessianUpdateStrategyDefaults):
     """
     Pearson's Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+    Note:
+        a line search is recommended.

-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
@@ -868,7 +785,7 @@ class Pearson(_InverseHessianUpdateStrategyDefaults):
         return pearson_H_(H=H, s=s, y=y)

 def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
-    sy = _safe_clip(s.dot(y))
+    sy = safe_clip(s.dot(y))
     num = (s - H@y).outer(s)
     H += num.div_(sy)
     return H
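Both are rank-1 secant corrections that differ only in the outer-product direction:

```latex
\text{Pearson: } H_{+} = H + \frac{(s - H y)(H y)^\top}{y^\top H y},
\qquad
\text{McCormick: } H_{+} = H + \frac{(s - H y)\, s^\top}{s^\top y}.
```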
@@ -876,14 +793,11 @@ def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
 class McCormick(_InverseHessianUpdateStrategyDefaults):
     """McCormicks's Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a line search is recommended.

-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
@@ -895,7 +809,7 @@ class McCormick(_InverseHessianUpdateStrategyDefaults):

 def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     Hy = H @ y
-    yHy = _safe_clip(y.dot(Hy))
+    yHy = safe_clip(y.dot(Hy))
     H -= Hy.outer(Hy) / yHy
     R += (s - R@y).outer(Hy) / yHy
     return H, R
@@ -904,14 +818,11 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
     """
     Projected Newton Raphson method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
-    .. note::
-        this is an experimental method.
+    Note:
+        a line search is recommended.

-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
@@ -921,15 +832,14 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
     def __init__(
         self,
         init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-8,
-        ptol: float | None = 1e-10,
-        ptol_reset: bool = False,
-        gtol: float | None = 1e-10,
-        reset_interval: int | None | Literal['auto'] = 'auto',
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = 'auto',
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool = True,
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inner: Chainable | None = None,
     ):
@@ -937,13 +847,12 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
             init_scale=init_scale,
             tol=tol,
             ptol = ptol,
-            ptol_reset=ptol_reset,
+            ptol_restart=ptol_restart,
             gtol=gtol,
-            reset_interval=reset_interval,
+            restart_interval=restart_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
-            scale_second=scale_second,
             concat_params=concat_params,
             inverse=True,
             inner=inner,
@@ -955,9 +864,10 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
         state["R"] = R
         return H

-    def _reset_M_(self, M, s, y, inverse, init_scale, state):
+    def reset_P(self, P, s, y, inverse, init_scale, state):
         assert inverse
-        M.copy_(state["R"])
+        if 'R' not in state: state['R'] = torch.eye(P.size(-1), device=P.device, dtype=P.dtype)
+        P.copy_(state["R"])

 # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
 def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
@@ -969,8 +879,8 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
     # however p.12 says eps = gs / gHy

     Hy = H@y
-    gHy = _safe_clip(g.dot(Hy))
-    yHy = _safe_clip(y.dot(Hy))
+    gHy = safe_clip(g.dot(Hy))
+    yHy = safe_clip(y.dot(Hy))
     sy = s.dot(y)
     if sy < tol: return H # the proof is for sy>0. But not clear if it should be skipped

@@ -987,26 +897,26 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
     e = gs / gHy
     if switch in (1, 3):
         if e/o <= 1:
-            phi = e/_safe_clip(o)
+            phi = e/safe_clip(o)
             theta = 0
         elif o/t >= 1:
-            phi = o/_safe_clip(t)
+            phi = o/safe_clip(t)
             theta = 1
         else:
             phi = 1
-            denom = _safe_clip(e*t - o**2)
+            denom = safe_clip(e*t - o**2)
             if switch == 1: theta = o * (e - o) / denom
             else: theta = o * (t - o) / denom

     elif switch == 2:
-        t = _safe_clip(t)
-        o = _safe_clip(o)
-        e = _safe_clip(e)
+        t = safe_clip(t)
+        o = safe_clip(o)
+        e = safe_clip(e)
         phi = (e / t) ** 0.5
         theta = 1 / (1 + (t*e / o**2)**0.5)

     elif switch == 4:
-        phi = e/_safe_clip(t)
+        phi = e/safe_clip(t)
         theta = 1/2

     else: raise ValueError(switch)
@@ -1028,14 +938,11 @@ class SSVM(HessianUpdateStrategy):
     """
     Self-scaling variable metric Quasi-Newton method.

-    .. note::
-        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a line search is recommended.

-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.

     Reference:
         Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
@@ -1044,15 +951,14 @@ class SSVM(HessianUpdateStrategy):
         self,
         switch: tuple[float,float] | Literal[1,2,3,4] = 3,
         init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-8,
-        ptol: float | None = 1e-10,
-        ptol_reset: bool = False,
-        gtol: float | None = 1e-10,
-        reset_interval: int | None = None,
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool = True,
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inner: Chainable | None = None,
     ):
@@ -1062,13 +968,12 @@ class SSVM(HessianUpdateStrategy):
  init_scale=init_scale,
  tol=tol,
  ptol=ptol,
- ptol_reset=ptol_reset,
+ ptol_restart=ptol_restart,
  gtol=gtol,
- reset_interval=reset_interval,
+ restart_interval=restart_interval,
  beta=beta,
  update_freq=update_freq,
  scale_first=scale_first,
- scale_second=scale_second,
  concat_params=concat_params,
  inverse=True,
  inner=inner,
@@ -1083,7 +988,7 @@ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  ys = y.dot(s)
  if ys.abs() <= tol: return H # skipping here is a guess, since this update is BFGS- and DFP-like
  yHy = y.dot(Hy)
- denom = _safe_clip(ys + yHy)
+ denom = safe_clip(ys + yHy)
 
  term1 = 1/denom
  term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
@@ -1096,7 +1001,7 @@ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  return H
 
  def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
- sy = _safe_clip(s.dot(y))
+ sy = safe_clip(s.dot(y))
  return g - (y * (s.dot(g) / sy))
 
 
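``gradient_correction`` subtracts a multiple of ``y`` chosen so the result has zero inner product with the step ``s``. A plain-tensor sketch verifying that property (the real version takes ``TensorList`` and guards the denominator with ``safe_clip``):

```python
import torch

# Plain-tensor version of gradient_correction above; the corrected
# gradient is orthogonal to the step s by construction.
def gradient_correction(g, s, y):
    sy = s.dot(y)
    return g - y * (s.dot(g) / sy)

g, s, y = torch.randn(3, 10).unbind(0)
r = gradient_correction(g, s, y)
print(s.dot(r))  # ~0 up to floating point error
```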
@@ -1106,16 +1011,16 @@ class GradientCorrection(Transform):
 
  This can be useful as an inner module for second order methods with an inexact line search.
 
- Example:
- L-BFGS with gradient correction
-
- .. code-block :: python
+ ## Example:
+ L-BFGS with gradient correction
 
- opt = tz.Modular(
- model.parameters(),
- tz.m.LBFGS(inner=tz.m.GradientCorrection()),
- tz.m.Backtracking()
- )
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LBFGS(inner=tz.m.GradientCorrection()),
+ tz.m.Backtracking()
+ )
+ ```
 
  Reference:
  HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
@@ -1141,14 +1046,11 @@ class Horisho(_InverseHessianUpdateStrategyDefaults):
  """
  Hoshino's variable metric Quasi-Newton method.
 
- .. note::
- a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
- .. note::
- BFGS is the recommended QN method and will usually outperform this.
+ Note:
+ a line search is recommended.
 
- .. warning::
- this uses roughly O(N^2) memory.
+ Warning:
+ this uses at least O(N^2) memory.
 
  Reference:
  HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
@@ -1175,14 +1077,11 @@ class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
  """
  Fletcher's variable metric Quasi-Newton method.
 
- .. note::
- a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
-
- .. note::
- BFGS is the recommended QN method and will usually outperform this.
+ Note:
+ a line search is recommended.
 
- .. warning::
- this uses roughly O(N^2) memory.
+ Warning:
+ this uses at least O(N^2) memory.
 
  Reference:
  Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
@@ -1218,10 +1117,10 @@ def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol:
  class NewSSM(HessianUpdateStrategy):
  """Self-scaling Quasi-Newton method.
 
- .. note::
- a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.
+ Note:
+ a line search such as ``tz.m.StrongWolfe()`` is required.
 
- .. warning::
+ Warning:
  this uses roughly O(N^2) memory.
 
  Reference:
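A hypothetical usage sketch (``tz`` and ``model`` assumed, as in the other examples in this diff); the note above says a line search is required, so ``NewSSM`` is paired with ``StrongWolfe``:

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.NewSSM(),
    tz.m.StrongWolfe(),
)
```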
@@ -1231,15 +1130,14 @@ class NewSSM(HessianUpdateStrategy):
  self,
  type: Literal[1, 2] = 1,
  init_scale: float | Literal["auto"] = "auto",
- tol: float = 1e-8,
- ptol: float | None = 1e-10,
- ptol_reset: bool = False,
- gtol: float | None = 1e-10,
- reset_interval: int | None = None,
+ tol: float = 1e-32,
+ ptol: float | None = 1e-32,
+ ptol_restart: bool = False,
+ gtol: float | None = 1e-32,
+ restart_interval: int | None = None,
  beta: float | None = None,
  update_freq: int = 1,
- scale_first: bool = True,
- scale_second: bool = False,
+ scale_first: bool = False,
  concat_params: bool = True,
  inner: Chainable | None = None,
  ):
@@ -1248,13 +1146,12 @@ class NewSSM(HessianUpdateStrategy):
  init_scale=init_scale,
  tol=tol,
  ptol=ptol,
- ptol_reset=ptol_reset,
+ ptol_restart=ptol_restart,
  gtol=gtol,
- reset_interval=reset_interval,
+ restart_interval=restart_interval,
  beta=beta,
  update_freq=update_freq,
  scale_first=scale_first,
- scale_second=scale_second,
  concat_params=concat_params,
  inverse=True,
  inner=inner,
@@ -1267,44 +1164,48 @@ class NewSSM(HessianUpdateStrategy):
  # ---------------------------- Shor’s r-algorithm ---------------------------- #
  # def shor_r(B:torch.Tensor, y:torch.Tensor, gamma:float):
  # r = B.T @ y
- # r /= torch.linalg.vector_norm(r).clip(min=1e-8) # pylint:disable=not-callable
+ # r /= torch.linalg.vector_norm(r).clip(min=1e-32) # pylint:disable=not-callable
 
  # I = torch.eye(B.size(1), device=B.device, dtype=B.dtype)
  # return B @ (I - gamma*r.outer(r))
 
- # this is supposed to be equivalent
+ # this is supposed to be equivalent (and it is)
  def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
  p = H@y
  # (1 - alpha^2) * (p p^T) / (p^T y)
- term = p.outer(p).div_(p.dot(y).clip(min=1e-8))
+ #term = p.outer(p).div_(p.dot(y).clip(min=1e-32))
+ term = p.outer(p).div_(safe_clip(p.dot(y)))
  H.sub_(term, alpha=1-alpha**2)
  return H
 
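The new comment claims the rank-one ``H`` update is equivalent to the commented-out ``B``-matrix form. A quick numerical check of that claim, assuming ``H = B @ B.T`` and that ``alpha`` plays the role of ``1 - gamma`` (that correspondence is an inference, not stated in the diff):

```python
import torch

torch.manual_seed(0)
n = 5
B = torch.randn(n, n)
y = torch.randn(n)
gamma = 0.3
alpha = 1 - gamma  # assumed mapping between the two parametrizations

# B-matrix form (Shor, 1985): B <- B (I - gamma r r^T), then H = B B^T
r = B.T @ y
r = r / torch.linalg.vector_norm(r)
B_new = B @ (torch.eye(n) - gamma * torch.outer(r, r))
H_from_B = B_new @ B_new.T

# rank-one H form, as in shor_r_ above
H = B @ B.T
p = H @ y
H_direct = H - (1 - alpha**2) * torch.outer(p, p) / p.dot(y)

print(torch.allclose(H_from_B, H_direct, atol=1e-5))  # True
```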
  class ShorR(HessianUpdateStrategy):
  """Shor’s r-algorithm.
 
- .. note::
- a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.
+ Note:
+ A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
+ Like conjugate gradient, ShorR has no automatic step size scaling,
+ so setting ``a_init`` in the line search is recommended.
 
- Reference:
- Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720.
+ References:
+ Shor, N. Z. (1985). Minimization Methods for Non-differentiable Functions. New York: Springer.
+
+ Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of Numerical Analysis 28.4 (2008): 711-720. (A good overview.)
 
- Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
+ Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. (This is where the more efficient formula is described.)
  """
 
  def __init__(
  self,
  alpha=0.5,
  init_scale: float | Literal["auto"] = 1,
- tol: float = 1e-8,
- ptol: float | None = 1e-10,
- ptol_reset: bool = False,
- gtol: float | None = 1e-10,
- reset_interval: int | None | Literal['auto'] = None,
+ tol: float = 1e-32,
+ ptol: float | None = 1e-32,
+ ptol_restart: bool = False,
+ gtol: float | None = 1e-32,
+ restart_interval: int | None | Literal['auto'] = None,
  beta: float | None = None,
  update_freq: int = 1,
  scale_first: bool = False,
- scale_second: bool = False,
  concat_params: bool = True,
  # inverse: bool = True,
  inner: Chainable | None = None,
@@ -1315,13 +1216,12 @@ class ShorR(HessianUpdateStrategy):
  init_scale=init_scale,
  tol=tol,
  ptol=ptol,
- ptol_reset=ptol_reset,
+ ptol_restart=ptol_restart,
  gtol=gtol,
- reset_interval=reset_interval,
+ restart_interval=restart_interval,
  beta=beta,
  update_freq=update_freq,
  scale_first=scale_first,
- scale_second=scale_second,
  concat_params=concat_params,
  inverse=True,
  inner=inner,