torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/quasi_newton/damping.py (added)
@@ -0,0 +1,105 @@
+ import math
+ from typing import Literal, Protocol, overload
+
+ import torch
+
+ from ...utils import TensorList
+ from ...utils.linalg.linear_operator import DenseInverse, LinearOperator
+ from ..functional import safe_clip
+
+
+ class DampingStrategy(Protocol):
+     def __call__(
+         self,
+         s: torch.Tensor,
+         y: torch.Tensor,
+         g: torch.Tensor,
+         H: LinearOperator,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         return s, y
+
+ def _sy_Hs_sHs(s:torch.Tensor, y:torch.Tensor, H:LinearOperator):
+     if isinstance(H, DenseInverse):
+         Hs = H.solve(y)
+         sHs = y.dot(Hs)
+     else:
+         Hs = H.matvec(s)
+         sHs = s.dot(Hs)
+
+     return s.dot(y), Hs, sHs
+
+
+
+ def powell_damping(s:torch.Tensor, y:torch.Tensor, g:torch.Tensor, H:LinearOperator, u=0.2):
+     # here H is hessian! not the inverse
+
+     sy, Hs, sHs = _sy_Hs_sHs(s, y, H)
+     if sy < u*sHs:
+         phi = ((1-u) * sHs) / safe_clip((sHs - sy))
+         s = phi * s + (1 - phi) * Hs
+
+     return s, y
+
+ def double_damping(s:torch.Tensor, y:torch.Tensor, g:torch.Tensor, H:LinearOperator, u1=0.2, u2=1/3):
+     # Goldfarb, Donald, Yi Ren, and Achraf Bahamou. "Practical quasi-newton methods for training deep neural networks." Advances in Neural Information Processing Systems 33 (2020): 2386-2396.
+
+     # Powell’s damping on H
+     sy, Hs, sHs = _sy_Hs_sHs(s, y, H)
+     if sy < u1*sHs:
+         phi = ((1-u1) * sHs) / safe_clip(sHs - sy)
+         s = phi * s + (1 - phi) * Hs
+
+     # Powell’s damping with B = I
+     sy = s.dot(y)
+     ss = s.dot(s)
+
+     if sy < u2*ss:
+         phi = ((1-u2) * ss) / safe_clip(ss - sy)
+         y = phi * y + (1 - phi) * s
+
+     return s, y
+
+
+
+ _DAMPING_KEYS = Literal["powell", "double"]
+ _DAMPING_STRATEGIES: dict[_DAMPING_KEYS, DampingStrategy] = {
+     "powell": powell_damping,
+     "double": double_damping,
+ }
+
+
+ DampingStrategyType = _DAMPING_KEYS | DampingStrategy | None
+
+ @overload
+ def apply_damping(
+     strategy: DampingStrategyType,
+     s: torch.Tensor,
+     y: torch.Tensor,
+     g: torch.Tensor,
+     H: LinearOperator,
+ ) -> tuple[torch.Tensor, torch.Tensor]: ...
+ @overload
+ def apply_damping(
+     strategy: DampingStrategyType,
+     s: TensorList,
+     y: TensorList,
+     g: TensorList,
+     H: LinearOperator,
+ ) -> tuple[TensorList, TensorList]: ...
+ def apply_damping(
+     strategy: DampingStrategyType,
+     s,
+     y,
+     g,
+     H: LinearOperator,
+ ):
+     if strategy is None: return s, y
+     if isinstance(strategy, str): strategy = _DAMPING_STRATEGIES[strategy]
+
+     if isinstance(s, TensorList):
+         assert isinstance(y, TensorList) and isinstance(g, TensorList)
+         s_vec, y_vec = strategy(s.to_vec(), y.to_vec(), g.to_vec(), H)
+         return s.from_vec(s_vec), y.from_vec(y_vec)
+
+     assert isinstance(y, torch.Tensor) and isinstance(g, torch.Tensor)
+     return strategy(s, y, g, H)
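
The damping helpers above guard quasi-Newton curvature pairs (s, y): `powell_damping` interpolates s toward the model curvature direction whenever sᵀy drops below a fraction u of sᵀBs, and `double_damping` (Goldfarb et al., 2020) follows it with a second Powell step against B = I, which leaves the pair with sᵀy ≥ u2·sᵀs. As a rough standalone illustration, here is a dense-matrix sketch of `double_damping`; `double_damping_dense` is a hypothetical name, the `LinearOperator`/`DenseInverse` plumbing is replaced by a plain SPD matrix, and the `safe_clip` guards are elided:

```python
import torch

def double_damping_dense(s: torch.Tensor, y: torch.Tensor, B: torch.Tensor,
                         u1: float = 0.2, u2: float = 1 / 3):
    # Step 1: Powell's damping against the Hessian approximation B
    # (mirrors the matvec branch of _sy_Hs_sHs above).
    Bs = B @ s
    sBs, sy = s.dot(Bs), s.dot(y)
    if sy < u1 * sBs:
        phi = (1 - u1) * sBs / (sBs - sy)  # safe_clip on the denominator elided
        s = phi * s + (1 - phi) * Bs

    # Step 2: Powell's damping with B = I, pulling y toward s until
    # s^T y >= u2 * s^T s.
    sy, ss = s.dot(y), s.dot(s)
    if sy < u2 * ss:
        phi = (1 - u2) * ss / (ss - sy)
        y = phi * y + (1 - phi) * s

    return s, y

torch.manual_seed(0)
A = torch.randn(5, 5)
B = A @ A.T + torch.eye(5)                 # SPD Hessian approximation
s, y = torch.randn(5), torch.randn(5)
s_d, y_d = double_damping_dense(s, y, B)
assert s_d.dot(y_d) >= (1 / 3) * s_d.dot(s_d) - 1e-6  # curvature restored
```

In the package these strategies are reached through `apply_damping`, which resolves the `"powell"`/`"double"` string keys, flattens `TensorList` pairs with `to_vec()` before calling the strategy, and unflattens the damped result with `from_vec()`.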
torchzero/modules/quasi_newton/diagonal_quasi_newton.py
@@ -1,163 +1,167 @@
- from collections.abc import Callable
-
- import torch
-
- from .quasi_newton import (
-     HessianUpdateStrategy,
-     _HessianUpdateStrategyDefaults,
-     _InverseHessianUpdateStrategyDefaults,
-     _safe_clip,
- )
-
-
- def _diag_Bv(self: HessianUpdateStrategy):
-     B, is_inverse = self.get_B()
-
-     if is_inverse:
-         H=B
-         def Hxv(v): return v/H
-         return Hxv
-
-     def Bv(v): return B*v
-     return Bv
-
- def _diag_Hv(self: HessianUpdateStrategy):
-     H, is_inverse = self.get_H()
-
-     if is_inverse:
-         B=H
-         def Bxv(v): return v/B
-         return Bxv
-
-     def Hv(v): return H*v
-     return Hv
-
- def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-     sy = s.dot(y)
-     if sy < tol: return H
-
-     sy_sq = _safe_clip(sy**2)
-
-     num1 = (sy + (y * H * y)) * s*s
-     term1 = num1.div_(sy_sq)
-     num2 = (H * y * s).add_(s * y * H)
-     term2 = num2.div_(sy)
-     H += term1.sub_(term2)
-     return H
-
- class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
-     """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
-     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
-         return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
-
-     def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
-     def make_Bv(self): return _diag_Bv(self)
-     def make_Hv(self): return _diag_Hv(self)
-
- def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
-     z = s - H*y
-     denom = z.dot(y)
-
-     z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
-     y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
-
-     # if y_norm*z_norm < tol: return H
-
-     # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
-     if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
-     H += (z*z).div_(_safe_clip(denom))
-     return H
- class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
-     """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
-     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
-         return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
-     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
-         return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
-
-     def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
-     def make_Bv(self): return _diag_Bv(self)
-     def make_Hv(self): return _diag_Hv(self)
-
-
-
- # Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
- def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
-     denom = _safe_clip((s**4).sum())
-     num = s.dot(y) - (s*B).dot(s)
-     B += s**2 * (num/denom)
-     return B
-
- class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
-     """Diagonal quasi-cauchi method.
-
-     Reference:
-         Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. 1999. – Т. 9. – №. 4. – С. 1192-1204.
-     """
-     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
-         return diagonal_qc_B_(B=B, s=s, y=y)
-
-     def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
-     def make_Bv(self): return _diag_Bv(self)
-     def make_Hv(self): return _diag_Hv(self)
-
- # Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
- def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
-     E_sq = s**2 * B**2
-     denom = _safe_clip((s*E_sq).dot(s))
-     num = s.dot(y) - (s*B).dot(s)
-     B += E_sq * (num/denom)
-     return B
-
- class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
-     """Diagonal quasi-cauchi method.
-
-     Reference:
-         Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
-     """
-     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
-         return diagonal_wqc_B_(B=B, s=s, y=y)
-
-     def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
-     def make_Bv(self): return _diag_Bv(self)
-     def make_Hv(self): return _diag_Hv(self)
-
-
- # Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
- def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
-     denom = _safe_clip((s**4).sum())
-     num = s.dot(y) + s.dot(s) - (s*B).dot(s)
-     B += s**2 * (num/denom) - 1
-     return B
-
- class DNRTR(_HessianUpdateStrategyDefaults):
-     """Diagonal quasi-newton method.
-
-     Reference:
-         Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
-     """
-     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
-         return diagonal_wqc_B_(B=B, s=s, y=y)
-
-     def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
-     def make_Bv(self): return _diag_Bv(self)
-     def make_Hv(self): return _diag_Hv(self)
-
- # Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
- def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
-     denom = _safe_clip((s**4).sum())
-     num = s.dot(y)
-     B += s**2 * (num/denom)
-     return B
-
- class NewDQN(_HessianUpdateStrategyDefaults):
-     """Diagonal quasi-newton method.
-
-     Reference:
-         Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
-     """
-     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
-         return new_dqn_B_(B=B, s=s, y=y)
-
-     def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
-     def make_Bv(self): return _diag_Bv(self)
-     def make_Hv(self): return _diag_Hv(self)
+ from typing import Any, Literal
+ from collections.abc import Callable
+
+ import torch
+ from ...core import Chainable
+ from .quasi_newton import (
+     HessianUpdateStrategy,
+     _HessianUpdateStrategyDefaults,
+     _InverseHessianUpdateStrategyDefaults,
+ )
+
+ from ..functional import safe_clip
+
+
+ def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     sy = s.dot(y)
+     if sy < tol: return H
+
+     sy_sq = safe_clip(sy**2)
+
+     num1 = (sy + (y * H * y)) * s*s
+     term1 = num1.div_(sy_sq)
+     num2 = (H * y * s).add_(s * y * H)
+     term2 = num2.div_(sy)
+     H += term1.sub_(term2)
+     return H
+
+ class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
+     """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+ def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
+     z = s - H*y
+     denom = z.dot(y)
+
+     z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
+     y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
+
+     # if y_norm*z_norm < tol: return H
+
+     # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
+     if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
+     H += (z*z).div_(safe_clip(denom))
+     return H
+ class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
+     """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+
+
+ # Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
+ def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+     denom = safe_clip((s**4).sum())
+     num = s.dot(y) - (s*B).dot(s)
+     B += s**2 * (num/denom)
+     return B
+
+ class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
+     """Diagonal quasi-cauchi method.
+
+     Reference:
+         Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
+     """
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_qc_B_(B=B, s=s, y=y)
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+ # Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+ def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+     E_sq = s**2 * B**2
+     denom = safe_clip((s*E_sq).dot(s))
+     num = s.dot(y) - (s*B).dot(s)
+     B += E_sq * (num/denom)
+     return B
+
+ class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
+     """Diagonal quasi-cauchi method.
+
+     Reference:
+         Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+     """
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_wqc_B_(B=B, s=s, y=y)
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+ def _truncate(B: torch.Tensor, lb, ub):
+     return torch.where((B>lb).logical_and(B<ub), B, 1)
+
+ # Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+ def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+     denom = safe_clip((s**4).sum())
+     num = s.dot(y) + s.dot(s) - (s*B).dot(s)
+     B += s**2 * (num/denom) - 1
+     return B
+
+ class DNRTR(HessianUpdateStrategy):
+     """Diagonal quasi-newton method.
+
+     Reference:
+         Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+     """
+     def __init__(
+         self,
+         lb: float = 1e-2,
+         ub: float = 1e5,
+         init_scale: float | Literal["auto"] = "auto",
+         tol: float = 1e-32,
+         ptol: float | None = 1e-32,
+         ptol_restart: bool = False,
+         gtol: float | None = 1e-32,
+         restart_interval: int | None | Literal['auto'] = None,
+         beta: float | None = None,
+         update_freq: int = 1,
+         scale_first: bool = False,
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(lb=lb, ub=ub)
+         super().__init__(
+             defaults=defaults,
+             init_scale=init_scale,
+             tol=tol,
+             ptol=ptol,
+             ptol_restart=ptol_restart,
+             gtol=gtol,
+             restart_interval=restart_interval,
+             beta=beta,
+             update_freq=update_freq,
+             scale_first=scale_first,
+             concat_params=concat_params,
+             inverse=False,
+             inner=inner,
+         )
+
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return diagonal_wqc_B_(B=B, s=s, y=y)
+
+     def modify_B(self, B, state, setting):
+         return _truncate(B, setting['lb'], setting['ub'])
+
+     def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+
+ # Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+ def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+     denom = safe_clip((s**4).sum())
+     num = s.dot(y)
+     B += s**2 * (num/denom)
+     return B
+
+ class NewDQN(DNRTR):
+     """Diagonal quasi-newton method.
+
+     Reference:
+         Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+     """
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+         return new_dqn_B_(B=B, s=s, y=y)
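
The quasi-Cauchy family above (`diagonal_qc_B_`, `diagonal_wqc_B_`, `dnrtr_B_`) corrects diag(B) so that the weak secant (quasi-Cauchy) relation sᵀdiag(B)s = sᵀy holds exactly after each update; the weighted variant and DNRTR change only the weighting of the correction and, for DNRTR, the truncation of the resulting diagonal into [lb, ub]. (Note that in this diff `DNRTR.update_B` calls `diagonal_wqc_B_`, leaving the `dnrtr_B_` helper defined just above it unused, which looks unintentional.) Below is a self-contained check of the quasi-Cauchy property, assuming only the arithmetic shown in `diagonal_qc_B_`; `diagonal_qc` is a hypothetical name and the `safe_clip` guard is elided:

```python
import torch

def diagonal_qc(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Rank-one diagonal correction from Zhu, Nazareth & Wolkowicz (1999):
    # distribute the current violation of s^T diag(B) s = s^T y along s**2.
    denom = (s**4).sum()               # safe_clip guard from the diff elided
    num = s.dot(y) - (s * B).dot(s)    # violation of the quasi-Cauchy relation
    return B + s**2 * (num / denom)

torch.manual_seed(0)
B = torch.ones(6)                      # initialize_P starts the diagonal at ones
s, y = torch.randn(6), torch.randn(6)
B = diagonal_qc(B, s, y)
assert torch.allclose((s * B).dot(s), s.dot(y))  # weak secant holds exactly
```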