torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -0,0 +1,105 @@
1
+ import math
2
+ from typing import Literal, Protocol, overload
3
+
4
+ import torch
5
+
6
+ from ...utils import TensorList
7
+ from ...utils.linalg.linear_operator import DenseInverse, LinearOperator
8
+ from ..functional import safe_clip
9
+
10
+
11
+ class DampingStrategy(Protocol):
12
+ def __call__(
13
+ self,
14
+ s: torch.Tensor,
15
+ y: torch.Tensor,
16
+ g: torch.Tensor,
17
+ H: LinearOperator,
18
+ ) -> tuple[torch.Tensor, torch.Tensor]:
19
+ return s, y
20
+
21
+ def _sy_Hs_sHs(s:torch.Tensor, y:torch.Tensor, H:LinearOperator):
22
+ if isinstance(H, DenseInverse):
23
+ Hs = H.solve(y)
24
+ sHs = y.dot(Hs)
25
+ else:
26
+ Hs = H.matvec(s)
27
+ sHs = s.dot(Hs)
28
+
29
+ return s.dot(y), Hs, sHs
30
+
31
+
32
+
33
+ def powell_damping(s:torch.Tensor, y:torch.Tensor, g:torch.Tensor, H:LinearOperator, u=0.2):
34
+ # here H is hessian! not the inverse
35
+
36
+ sy, Hs, sHs = _sy_Hs_sHs(s, y, H)
37
+ if sy < u*sHs:
38
+ phi = ((1-u) * sHs) / safe_clip((sHs - sy))
39
+ s = phi * s + (1 - phi) * Hs
40
+
41
+ return s, y
42
+
43
+ def double_damping(s:torch.Tensor, y:torch.Tensor, g:torch.Tensor, H:LinearOperator, u1=0.2, u2=1/3):
44
+ # Goldfarb, Donald, Yi Ren, and Achraf Bahamou. "Practical quasi-newton methods for training deep neural networks." Advances in Neural Information Processing Systems 33 (2020): 2386-2396.
45
+
46
+ # Powell’s damping on H
47
+ sy, Hs, sHs = _sy_Hs_sHs(s, y, H)
48
+ if sy < u1*sHs:
49
+ phi = ((1-u1) * sHs) / safe_clip(sHs - sy)
50
+ s = phi * s + (1 - phi) * Hs
51
+
52
+ # Powell’s damping with B = I
53
+ sy = s.dot(y)
54
+ ss = s.dot(s)
55
+
56
+ if sy < u2*ss:
57
+ phi = ((1-u2) * ss) / safe_clip(ss - sy)
58
+ y = phi * y + (1 - phi) * s
59
+
60
+ return s, y
61
+
62
+
63
+
64
+ _DAMPING_KEYS = Literal["powell", "double"]
65
+ _DAMPING_STRATEGIES: dict[_DAMPING_KEYS, DampingStrategy] = {
66
+ "powell": powell_damping,
67
+ "double": double_damping,
68
+ }
69
+
70
+
71
+ DampingStrategyType = _DAMPING_KEYS | DampingStrategy | None
72
+
73
+ @overload
74
+ def apply_damping(
75
+ strategy: DampingStrategyType,
76
+ s: torch.Tensor,
77
+ y: torch.Tensor,
78
+ g: torch.Tensor,
79
+ H: LinearOperator,
80
+ ) -> tuple[torch.Tensor, torch.Tensor]: ...
81
+ @overload
82
+ def apply_damping(
83
+ strategy: DampingStrategyType,
84
+ s: TensorList,
85
+ y: TensorList,
86
+ g: TensorList,
87
+ H: LinearOperator,
88
+ ) -> tuple[TensorList, TensorList]: ...
89
+ def apply_damping(
90
+ strategy: DampingStrategyType,
91
+ s,
92
+ y,
93
+ g,
94
+ H: LinearOperator,
95
+ ):
96
+ if strategy is None: return s, y
97
+ if isinstance(strategy, str): strategy = _DAMPING_STRATEGIES[strategy]
98
+
99
+ if isinstance(s, TensorList):
100
+ assert isinstance(y, TensorList) and isinstance(g, TensorList)
101
+ s_vec, y_vec = strategy(s.to_vec(), y.to_vec(), g.to_vec(), H)
102
+ return s.from_vec(s_vec), y.from_vec(y_vec)
103
+
104
+ assert isinstance(y, torch.Tensor) and isinstance(g, torch.Tensor)
105
+ return strategy(s, y, g, H)
@@ -1,163 +1,167 @@
1
- from collections.abc import Callable
2
-
3
- import torch
4
-
5
- from .quasi_newton import (
6
- HessianUpdateStrategy,
7
- _HessianUpdateStrategyDefaults,
8
- _InverseHessianUpdateStrategyDefaults,
9
- _safe_clip,
10
- )
11
-
12
-
13
- def _diag_Bv(self: HessianUpdateStrategy):
14
- B, is_inverse = self.get_B()
15
-
16
- if is_inverse:
17
- H=B
18
- def Hxv(v): return v/H
19
- return Hxv
20
-
21
- def Bv(v): return B*v
22
- return Bv
23
-
24
- def _diag_Hv(self: HessianUpdateStrategy):
25
- H, is_inverse = self.get_H()
26
-
27
- if is_inverse:
28
- B=H
29
- def Bxv(v): return v/B
30
- return Bxv
31
-
32
- def Hv(v): return H*v
33
- return Hv
34
-
35
- def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
36
- sy = s.dot(y)
37
- if sy < tol: return H
38
-
39
- sy_sq = _safe_clip(sy**2)
40
-
41
- num1 = (sy + (y * H * y)) * s*s
42
- term1 = num1.div_(sy_sq)
43
- num2 = (H * y * s).add_(s * y * H)
44
- term2 = num2.div_(sy)
45
- H += term1.sub_(term2)
46
- return H
47
-
48
- class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
49
- """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
50
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
51
- return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
52
-
53
- def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
54
- def make_Bv(self): return _diag_Bv(self)
55
- def make_Hv(self): return _diag_Hv(self)
56
-
57
- def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
58
- z = s - H*y
59
- denom = z.dot(y)
60
-
61
- z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
62
- y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
63
-
64
- # if y_norm*z_norm < tol: return H
65
-
66
- # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
67
- if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
68
- H += (z*z).div_(_safe_clip(denom))
69
- return H
70
- class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
71
- """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
72
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
73
- return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
74
- def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
75
- return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
76
-
77
- def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
78
- def make_Bv(self): return _diag_Bv(self)
79
- def make_Hv(self): return _diag_Hv(self)
80
-
81
-
82
-
83
- # Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
84
- def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
85
- denom = _safe_clip((s**4).sum())
86
- num = s.dot(y) - (s*B).dot(s)
87
- B += s**2 * (num/denom)
88
- return B
89
-
90
- class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
91
- """Diagonal quasi-cauchi method.
92
-
93
- Reference:
94
- Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. 1999. – Т. 9. – №. 4. – С. 1192-1204.
95
- """
96
- def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
97
- return diagonal_qc_B_(B=B, s=s, y=y)
98
-
99
- def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
100
- def make_Bv(self): return _diag_Bv(self)
101
- def make_Hv(self): return _diag_Hv(self)
102
-
103
- # Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
104
- def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
105
- E_sq = s**2 * B**2
106
- denom = _safe_clip((s*E_sq).dot(s))
107
- num = s.dot(y) - (s*B).dot(s)
108
- B += E_sq * (num/denom)
109
- return B
110
-
111
- class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
112
- """Diagonal quasi-cauchi method.
113
-
114
- Reference:
115
- Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
116
- """
117
- def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
118
- return diagonal_wqc_B_(B=B, s=s, y=y)
119
-
120
- def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
121
- def make_Bv(self): return _diag_Bv(self)
122
- def make_Hv(self): return _diag_Hv(self)
123
-
124
-
125
- # Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
126
- def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
127
- denom = _safe_clip((s**4).sum())
128
- num = s.dot(y) + s.dot(s) - (s*B).dot(s)
129
- B += s**2 * (num/denom) - 1
130
- return B
131
-
132
- class DNRTR(_HessianUpdateStrategyDefaults):
133
- """Diagonal quasi-newton method.
134
-
135
- Reference:
136
- Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
137
- """
138
- def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
139
- return diagonal_wqc_B_(B=B, s=s, y=y)
140
-
141
- def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
142
- def make_Bv(self): return _diag_Bv(self)
143
- def make_Hv(self): return _diag_Hv(self)
144
-
145
- # Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
146
- def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
147
- denom = _safe_clip((s**4).sum())
148
- num = s.dot(y)
149
- B += s**2 * (num/denom)
150
- return B
151
-
152
- class NewDQN(_HessianUpdateStrategyDefaults):
153
- """Diagonal quasi-newton method.
154
-
155
- Reference:
156
- Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
157
- """
158
- def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
159
- return new_dqn_B_(B=B, s=s, y=y)
160
-
161
- def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
162
- def make_Bv(self): return _diag_Bv(self)
163
- def make_Hv(self): return _diag_Hv(self)
1
+ from typing import Any, Literal
2
+ from collections.abc import Callable
3
+
4
+ import torch
5
+ from ...core import Chainable
6
+ from .quasi_newton import (
7
+ HessianUpdateStrategy,
8
+ _HessianUpdateStrategyDefaults,
9
+ _InverseHessianUpdateStrategyDefaults,
10
+ )
11
+
12
+ from ..functional import safe_clip
13
+
14
+
15
+ def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
16
+ sy = s.dot(y)
17
+ if sy < tol: return H
18
+
19
+ sy_sq = safe_clip(sy**2)
20
+
21
+ num1 = (sy + (y * H * y)) * s*s
22
+ term1 = num1.div_(sy_sq)
23
+ num2 = (H * y * s).add_(s * y * H)
24
+ term2 = num2.div_(sy)
25
+ H += term1.sub_(term2)
26
+ return H
27
+
28
+ class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
29
+ """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
30
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
31
+ return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
32
+
33
+ def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
34
+
35
+ def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
36
+ z = s - H*y
37
+ denom = z.dot(y)
38
+
39
+ z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
40
+ y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
41
+
42
+ # if y_norm*z_norm < tol: return H
43
+
44
+ # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
45
+ if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
46
+ H += (z*z).div_(safe_clip(denom))
47
+ return H
48
+ class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
49
+ """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
50
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
51
+ return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
52
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
53
+ return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
54
+
55
+ def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
56
+
57
+
58
+
59
+ # Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
60
+ def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
61
+ denom = safe_clip((s**4).sum())
62
+ num = s.dot(y) - (s*B).dot(s)
63
+ B += s**2 * (num/denom)
64
+ return B
65
+
66
+ class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
67
+ """Diagonal quasi-cauchi method.
68
+
69
+ Reference:
70
+ Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
71
+ """
72
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
73
+ return diagonal_qc_B_(B=B, s=s, y=y)
74
+
75
+ def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
76
+
77
+ # Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
78
+ def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
79
+ E_sq = s**2 * B**2
80
+ denom = safe_clip((s*E_sq).dot(s))
81
+ num = s.dot(y) - (s*B).dot(s)
82
+ B += E_sq * (num/denom)
83
+ return B
84
+
85
+ class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
86
+ """Diagonal quasi-cauchi method.
87
+
88
+ Reference:
89
+ Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
90
+ """
91
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
92
+ return diagonal_wqc_B_(B=B, s=s, y=y)
93
+
94
+ def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
95
+
96
+ def _truncate(B: torch.Tensor, lb, ub):
97
+ return torch.where((B>lb).logical_and(B<ub), B, 1)
98
+
99
+ # Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
100
+ def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
101
+ denom = safe_clip((s**4).sum())
102
+ num = s.dot(y) + s.dot(s) - (s*B).dot(s)
103
+ B += s**2 * (num/denom) - 1
104
+ return B
105
+
106
+ class DNRTR(HessianUpdateStrategy):
107
+ """Diagonal quasi-newton method.
108
+
109
+ Reference:
110
+ Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
111
+ """
112
+ def __init__(
113
+ self,
114
+ lb: float = 1e-2,
115
+ ub: float = 1e5,
116
+ init_scale: float | Literal["auto"] = "auto",
117
+ tol: float = 1e-32,
118
+ ptol: float | None = 1e-32,
119
+ ptol_restart: bool = False,
120
+ gtol: float | None = 1e-32,
121
+ restart_interval: int | None | Literal['auto'] = None,
122
+ beta: float | None = None,
123
+ update_freq: int = 1,
124
+ scale_first: bool = False,
125
+ concat_params: bool = True,
126
+ inner: Chainable | None = None,
127
+ ):
128
+ defaults = dict(lb=lb, ub=ub)
129
+ super().__init__(
130
+ defaults=defaults,
131
+ init_scale=init_scale,
132
+ tol=tol,
133
+ ptol=ptol,
134
+ ptol_restart=ptol_restart,
135
+ gtol=gtol,
136
+ restart_interval=restart_interval,
137
+ beta=beta,
138
+ update_freq=update_freq,
139
+ scale_first=scale_first,
140
+ concat_params=concat_params,
141
+ inverse=False,
142
+ inner=inner,
143
+ )
144
+
145
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
146
+ return diagonal_wqc_B_(B=B, s=s, y=y)
147
+
148
+ def modify_B(self, B, state, setting):
149
+ return _truncate(B, setting['lb'], setting['ub'])
150
+
151
+ def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
152
+
153
+ # Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
154
+ def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
155
+ denom = safe_clip((s**4).sum())
156
+ num = s.dot(y)
157
+ B += s**2 * (num/denom)
158
+ return B
159
+
160
+ class NewDQN(DNRTR):
161
+ """Diagonal quasi-newton method.
162
+
163
+ Reference:
164
+ Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
165
+ """
166
+ def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
167
+ return new_dqn_B_(B=B, s=s, y=y)