torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/restarts/__init__.py
@@ -0,0 +1,7 @@
+ from .restars import (
+     BirginMartinezRestart,
+     PowellRestart,
+     RestartEvery,
+     RestartOnStuck,
+     RestartStrategyBase,
+ )
torchzero/modules/restarts/restars.py
@@ -0,0 +1,253 @@
+ from abc import ABC, abstractmethod
+ from functools import partial
+ from typing import final, Literal, cast
+
+ import torch
+
+ from ...core import Chainable, Module, Var
+ from ...utils import TensorList
+ from ..termination import TerminationCriteriaBase
+
+ def _reset_except_self(optimizer, var, self: Module):
+     for m in optimizer.unrolled_modules: m.reset()
+
+ class RestartStrategyBase(Module, ABC):
+     """Base class for restart strategies.
+
+     On each ``update``/``step`` this checks the reset condition and, if it is satisfied,
+     resets the modules before updating or stepping.
+     """
+     def __init__(self, defaults: dict | None = None, modules: Chainable | None = None):
+         if defaults is None: defaults = {}
+         super().__init__(defaults)
+         if modules is not None:
+             self.set_child('modules', modules)
+
+     @abstractmethod
+     def should_reset(self, var: Var) -> bool:
+         """returns whether reset should occur"""
+
+     def _reset_on_condition(self, var):
+         modules = self.children.get('modules', None)
+
+         if self.should_reset(var):
+             if modules is None:
+                 var.post_step_hooks.append(partial(_reset_except_self, self=self))
+             else:
+                 modules.reset()
+
+         return modules
+
+     @final
+     def update(self, var):
+         modules = self._reset_on_condition(var)
+         if modules is not None:
+             modules.update(var)
+
+     @final
+     def apply(self, var):
+         # don't check here because it was checked in `update`
+         modules = self.children.get('modules', None)
+         if modules is None: return var
+         return modules.apply(var.clone(clone_update=False))
+
+     @final
+     def step(self, var):
+         modules = self._reset_on_condition(var)
+         if modules is None: return var
+         return modules.step(var.clone(clone_update=False))
+
+
+
+ class RestartOnStuck(RestartStrategyBase):
+     """Resets the state when the update (difference in parameters) is zero for multiple steps in a row.
+
+     Args:
+         modules (Chainable | None):
+             modules to reset. If None, resets all modules.
+         tol (float, optional):
+             a step is considered failed when the maximum absolute parameter difference is smaller than this. Defaults to None (uses twice the smallest representable number).
+         n_tol (int, optional):
+             number of consecutive failed steps required to trigger a reset. Defaults to 10.
+
+     """
+     def __init__(self, modules: Chainable | None, tol: float | None = None, n_tol: int = 10):
+         defaults = dict(tol=tol, n_tol=n_tol)
+         super().__init__(defaults, modules)
+
+     @torch.no_grad
+     def should_reset(self, var):
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         params = TensorList(var.params)
+         tol = self.defaults['tol']
+         if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
+         n_tol = self.defaults['n_tol']
+         n_bad = self.global_state.get('n_bad', 0)
+
+         # calculate difference in parameters
+         prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+         update = params - prev_params
+         prev_params.copy_(params)
+
+         # if update is too small, it is considered bad, otherwise n_bad is reset to 0
+         if step > 0:
+             if update.abs().global_max() <= tol:
+                 n_bad += 1
+
+             else:
+                 n_bad = 0
+
+         self.global_state['n_bad'] = n_bad
+
+         # no progress, reset
+         if n_bad >= n_tol:
+             self.global_state.clear()
+             return True
+
+         return False
+
+
+ class RestartEvery(RestartStrategyBase):
+     """Resets the state every n steps.
+
+     Args:
+         modules (Chainable | None):
+             modules to reset. If None, resets all modules.
+         steps (int | Literal["ndim"]):
+             number of steps between resets. "ndim" to use the number of parameters.
+     """
+     def __init__(self, modules: Chainable | None, steps: int | Literal['ndim']):
+         defaults = dict(steps=steps)
+         super().__init__(defaults, modules)
+
+     def should_reset(self, var):
+         step = self.global_state.get('step', 0) + 1
+         self.global_state['step'] = step
+
+         n = self.defaults['steps']
+         if isinstance(n, str): n = sum(p.numel() for p in var.params if p.requires_grad)
+
+         # reset every n steps
+         if step % n == 0:
+             self.global_state.clear()
+             return True
+
+         return False
+
+ class RestartOnTerminationCriteria(RestartStrategyBase):
+     def __init__(self, modules: Chainable | None, criteria: "TerminationCriteriaBase"):
+         super().__init__(None, modules)
+         self.set_child('criteria', criteria)
+
+     def should_reset(self, var):
+         criteria = cast(TerminationCriteriaBase, self.children['criteria'])
+         return criteria.should_terminate(var)
+
+ class PowellRestart(RestartStrategyBase):
+     """Powell's two restarting criteria for conjugate gradient methods.
+
+     The restart clears all states of ``modules``.
+
+     Args:
+         modules (Chainable | None):
+             modules to reset. If None, resets all modules.
+         cond1 (float | None, optional):
+             criterion that checks for nonconjugacy of the search directions.
+             Restart is performed whenever |g_k^T g_{k+1}| >= cond1 * ||g_{k+1}||^2.
+             The default condition value of 0.2 is suggested by Powell. Can be None to disable this criterion.
+         cond2 (float | None, optional):
+             criterion that checks whether the direction is not effectively downhill.
+             Restart is performed if -1.2 * ||g||^2 < d^T g < -0.8 * ||g||^2.
+             Defaults to 0.2. Can be None to disable this criterion.
+
+     Reference:
+         Powell, Michael James David. "Restart procedures for the conjugate gradient method." Mathematical Programming 12.1 (1977): 241-254.
+     """
+     def __init__(self, modules: Chainable | None, cond1: float | None = 0.2, cond2: float | None = 0.2):
+         defaults = dict(cond1=cond1, cond2=cond2)
+         super().__init__(defaults, modules)
+
+     def should_reset(self, var):
+         g = TensorList(var.get_grad())
+         cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
+
+         # -------------------------------- initialize -------------------------------- #
+         if 'initialized' not in self.global_state:
+             self.global_state['initialized'] = 0
+             g_prev = self.get_state(var.params, 'g_prev', init=g)
+             return False
+
+         g_g = g.dot(g)
+
+         reset = False
+         # ------------------------------- 1st condition ------------------------------ #
+         if cond1 is not None:
+             g_prev = self.get_state(var.params, 'g_prev', must_exist=True, cls=TensorList)
+             g_g_prev = g_prev.dot(g)
+
+             if g_g_prev.abs() >= cond1 * g_g:
+                 reset = True
+
+         # ------------------------------- 2nd condition ------------------------------ #
+         if (cond2 is not None) and (not reset):
+             d_g = TensorList(var.get_update()).dot(g)
+             if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
+                 reset = True
+
+         # ------------------------------ clear on reset ------------------------------ #
+         if reset:
+             self.global_state.clear()
+             self.clear_state_keys('g_prev')
+             return True
+
+         return False
+
+ # this requires direction from inner module,
+ # so both this and inner have to be a single module
+ class BirginMartinezRestart(Module):
+     """The restart criterion for conjugate gradient methods designed by Birgin and Martínez.
+
+     This criterion restarts when the angle between d_{k+1} and -g_{k+1} is not acute enough.
+
+     The restart clears all states of ``module``.
+
+     Args:
+         module (Module):
+             module to restart, should be a conjugate gradient or possibly a quasi-Newton method.
+         cond (float, optional):
+             Restart is performed whenever d^T g > -cond * ||d|| * ||g||.
+             The default condition value of 1e-3 is suggested by Birgin and Martínez.
+
+     Reference:
+         Birgin, Ernesto G., and José Mario Martínez. "A spectral conjugate gradient method for unconstrained optimization." Applied Mathematics & Optimization 43.2 (2001): 117-128.
+     """
+     def __init__(self, module: Module, cond: float = 1e-3):
+         defaults = dict(cond=cond)
+         super().__init__(defaults)
+
+         self.set_child("module", module)
+
+     def update(self, var):
+         module = self.children['module']
+         module.update(var)
+
+     def apply(self, var):
+         module = self.children['module']
+         var = module.apply(var.clone(clone_update=False))
+
+         cond = self.defaults['cond']
+         g = TensorList(var.get_grad())
+         d = TensorList(var.get_update())
+         d_g = d.dot(g)
+         d_norm = d.global_vector_norm()
+         g_norm = g.global_vector_norm()
+
+         # d in our case is in the same direction as g, so it gets a minus sign
+         if -d_g > -cond * d_norm * g_norm:
+             module.reset()
+             var.update = g.clone()
+             return var
+
+         return var
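
For reference, the two Powell tests implemented in PowellRestart.should_reset above reduce to the following check on flat gradient and direction vectors. This is a standalone sketch using plain torch tensors rather than torchzero's TensorList and module state; the function name and signature are illustrative and not part of the package.

    import torch

    def powell_should_restart(g_prev: torch.Tensor, g: torch.Tensor, d: torch.Tensor,
                              cond1: float = 0.2, cond2: float = 0.2) -> bool:
        # mirrors PowellRestart.should_reset: g_prev is the previous gradient,
        # g the current gradient, d the proposed update direction
        g_g = g.dot(g)
        # 1st condition: successive gradients are far from orthogonal (nonconjugacy)
        if g_prev.dot(g).abs() >= cond1 * g_g:
            return True
        # 2nd condition (as implemented above): d^T g falls inside (-(1+cond2), -(1-cond2)) * ||g||^2
        d_g = d.dot(g)
        return bool((-1 - cond2) * g_g < d_g < (-1 + cond2) * g_g)
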
torchzero/modules/second_order/__init__.py
@@ -1,3 +1,4 @@
  from .newton import Newton, InverseFreeNewton
- from .newton_cg import NewtonCG, TruncatedNewtonCG
+ from .newton_cg import NewtonCG, NewtonCGSteihaug
  from .nystrom import NystromSketchAndSolve, NystromPCG
+ from .multipoint import SixthOrder3P, SixthOrder5P, TwoPointNewton, SixthOrder3PM2
torchzero/modules/second_order/multipoint.py
@@ -0,0 +1,238 @@
+ from collections.abc import Callable
+ from contextlib import nullcontext
+ from abc import ABC, abstractmethod
+ import numpy as np
+ import torch
+
+ from ...core import Chainable, Module, apply_transform, Var
+ from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
+ from ...utils.derivatives import (
+     flatten_jacobian,
+     jacobian_wrt,
+ )
+
+ class HigherOrderMethodBase(Module, ABC):
+     def __init__(self, defaults: dict | None = None, vectorize: bool = True):
+         self._vectorize = vectorize
+         super().__init__(defaults)
+
+     @abstractmethod
+     def one_iteration(
+         self,
+         x: torch.Tensor,
+         evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
+         var: Var,
+     ) -> torch.Tensor:
+         """"""
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         x0 = params.clone()
+         closure = var.closure
+         if closure is None: raise RuntimeError('MultipointNewton requires closure')
+         vectorize = self._vectorize
+
+         def evaluate(x, order) -> tuple[torch.Tensor, ...]:
+             """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
+             params.from_vec_(x)
+
+             if order == 0:
+                 loss = closure(False)
+                 params.copy_(x0)
+                 return (loss, )
+
+             if order == 1:
+                 with torch.enable_grad():
+                     loss = closure()
+                 grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                 params.copy_(x0)
+                 return loss, torch.cat([g.ravel() for g in grad])
+
+             with torch.enable_grad():
+                 loss = var.loss = var.loss_approx = closure(False)
+
+                 g_list = torch.autograd.grad(loss, params, create_graph=True)
+                 var.grad = list(g_list)
+
+                 g = torch.cat([t.ravel() for t in g_list])
+                 n = g.numel()
+                 ret = [loss, g]
+                 T = g # current derivatives tensor
+
+                 # get all derivatives up to order
+                 for o in range(2, order + 1):
+                     is_last = o == order
+                     T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
+                     with torch.no_grad() if is_last else nullcontext():
+                         # the shape is (ndim, ) * order
+                         T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
+                         ret.append(T)
+
+             params.copy_(x0)
+             return tuple(ret)
+
+         x = torch.cat([p.ravel() for p in params])
+         dir = self.one_iteration(x, evaluate, var)
+         var.update = vec_to_tensors(dir, var.params)
+         return var
+
+ def _inv(A: torch.Tensor, lstsq: bool) -> torch.Tensor:
+     if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable
+     A_inv, info = torch.linalg.inv_ex(A) # pylint:disable=not-callable
+     if info == 0: return A_inv
+     return torch.linalg.pinv(A) # pylint:disable=not-callable
+
+ def _solve(A: torch.Tensor, b: torch.Tensor, lstsq: bool) -> torch.Tensor:
+     if lstsq: return torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+     sol, info = torch.linalg.solve_ex(A, b) # pylint:disable=not-callable
+     if info == 0: return sol
+     return torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+
+ # 3f 2J 3 solves
+ def sixth_order_3p(x: torch.Tensor, f, f_j, lstsq: bool = False):
+     f_x, J_x = f_j(x)
+
+     y = x - _solve(J_x, f_x, lstsq=lstsq)
+     f_y, J_y = f_j(y)
+
+     z = y - _solve(J_y, f_y, lstsq=lstsq)
+     f_z = f(z)
+
+     return y - _solve(J_y, f_y+f_z, lstsq=lstsq)
+
+ class SixthOrder3P(HigherOrderMethodBase):
+     """Sixth-order iterative method.
+
+     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
+     """
+     def __init__(self, lstsq: bool = False, vectorize: bool = True):
+         defaults = dict(lstsq=lstsq)
+         super().__init__(defaults=defaults, vectorize=vectorize)
+
+     def one_iteration(self, x, evaluate, var):
+         settings = self.defaults
+         lstsq = settings['lstsq']
+         def f(x): return evaluate(x, 1)[1]
+         def f_j(x): return evaluate(x, 2)[1:]
+         x_star = sixth_order_3p(x, f, f_j, lstsq)
+         return x - x_star
+
+ # I don't think it works (I tested root finding with this and it goes all over the place)
+ # I double checked it multiple times
+ # def sixth_order_im1(x:torch.Tensor, f, f_j, lstsq:bool=False):
+ #     f_x, J_x = f_j(x)
+ #     J_x_inv = _inv(J_x, lstsq=lstsq)
+
+ #     y = x - J_x_inv @ f_x
+ #     f_y, J_y = f_j(y)
+
+ #     z = x - 2 * _solve(J_x + J_y, f_x, lstsq=lstsq)
+ #     f_z = f(z)
+
+ #     I = torch.eye(J_y.size(0), device=J_y.device, dtype=J_y.dtype)
+ #     term1 = (7/2)*I
+ #     term2 = 4 * (J_x_inv@J_y)
+ #     term3 = (3/2) * (J_x_inv @ (J_y@J_y))
+
+ #     return z - (term1 - term2 + term3) @ J_x_inv @ f_z
+
+ # class SixthOrderIM1(HigherOrderMethodBase):
+ #     """sixth-order iterative method https://www.mdpi.com/2504-3110/8/3/133
+
+ #     """
+ #     def __init__(self, lstsq: bool=False, vectorize: bool = True):
+ #         defaults=dict(lstsq=lstsq)
+ #         super().__init__(defaults=defaults, vectorize=vectorize)
+
+ #     def iteration(self, x, evaluate, var):
+ #         settings = self.defaults
+ #         lstsq = settings['lstsq']
+ #         def f(x): return evaluate(x, 1)[1]
+ #         def f_j(x): return evaluate(x, 2)[1:]
+ #         x_star = sixth_order_im1(x, f, f_j, lstsq)
+ #         return x - x_star
+
+ # 5f 5J 3 solves
+ def sixth_order_5p(x: torch.Tensor, f_j, lstsq: bool = False):
+     f_x, J_x = f_j(x)
+     y = x - _solve(J_x, f_x, lstsq=lstsq)
+
+     f_y, J_y = f_j(y)
+     f_xy2, J_xy2 = f_j((x + y) / 2)
+
+     A = J_x + 2*J_xy2 + J_y
+
+     z = y - 4*_solve(A, f_y, lstsq=lstsq)
+     f_z, J_z = f_j(z)
+
+     f_xz2, J_xz2 = f_j((x + z) / 2)
+     B = J_x + 2*J_xz2 + J_z
+
+     return z - 4*_solve(B, f_z, lstsq=lstsq)
+
+ class SixthOrder5P(HigherOrderMethodBase):
+     """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
+     def __init__(self, lstsq: bool = False, vectorize: bool = True):
+         defaults = dict(lstsq=lstsq)
+         super().__init__(defaults=defaults, vectorize=vectorize)
+
+     def one_iteration(self, x, evaluate, var):
+         settings = self.defaults
+         lstsq = settings['lstsq']
+         def f_j(x): return evaluate(x, 2)[1:]
+         x_star = sixth_order_5p(x, f_j, lstsq)
+         return x - x_star
+
+ # 2f 1J 2 solves
+ def two_point_newton(x: torch.Tensor, f, f_j, lstsq: bool = False):
+     """third order convergence"""
+     f_x, J_x = f_j(x)
+     y = x - _solve(J_x, f_x, lstsq=lstsq)
+     f_y = f(y)
+     return x - _solve(J_x, f_x + f_y, lstsq=lstsq)
+
+ class TwoPointNewton(HigherOrderMethodBase):
+     """Two-point Newton method with frozen derivative and third order convergence.
+
+     Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
+     def __init__(self, lstsq: bool = False, vectorize: bool = True):
+         defaults = dict(lstsq=lstsq)
+         super().__init__(defaults=defaults, vectorize=vectorize)
+
+     def one_iteration(self, x, evaluate, var):
+         settings = self.defaults
+         lstsq = settings['lstsq']
+         def f(x): return evaluate(x, 1)[1]
+         def f_j(x): return evaluate(x, 2)[1:]
+         x_star = two_point_newton(x, f, f_j, lstsq)
+         return x - x_star
+
+ # 3f 2J 1 inv
+ def sixth_order_3pm2(x: torch.Tensor, f, f_j, lstsq: bool = False):
+     f_x, J_x = f_j(x)
+     J_x_inv = _inv(J_x, lstsq=lstsq)
+     y = x - J_x_inv @ f_x
+     f_y, J_y = f_j(y)
+
+     I = torch.eye(x.numel(), dtype=x.dtype, device=x.device)
+     term = (2*I - J_x_inv @ J_y) @ J_x_inv
+     z = y - term @ f_y
+
+     return z - term @ f(z)
+
+
+ class SixthOrder3PM2(HigherOrderMethodBase):
+     """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
+     def __init__(self, lstsq: bool = False, vectorize: bool = True):
+         defaults = dict(lstsq=lstsq)
+         super().__init__(defaults=defaults, vectorize=vectorize)
+
+     def one_iteration(self, x, evaluate, var):
+         settings = self.defaults
+         lstsq = settings['lstsq']
+         def f_j(x): return evaluate(x, 2)[1:]
+         def f(x): return evaluate(x, 1)[1]
+         x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+         return x - x_star
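
As a sanity check of the update formulas, the frozen-Jacobian two-point scheme used by two_point_newton above can be run standalone on a toy 2-equation root-finding problem. This is a minimal sketch outside of torchzero; F and F_J are illustrative stand-ins for the f/f_j callables that HigherOrderMethodBase.step builds from the closure.

    import torch

    def F(x):
        # residuals of the toy system: x0^2 + x1^2 = 4, x0 = x1
        return torch.stack([x[0]**2 + x[1]**2 - 4.0, x[0] - x[1]])

    def F_J(x):
        # residuals and Jacobian of F at x
        J = torch.stack([torch.stack([2 * x[0], 2 * x[1]]),
                         torch.tensor([1.0, -1.0])])
        return F(x), J

    x = torch.tensor([3.0, 1.0])
    for _ in range(5):
        f_x, J_x = F_J(x)
        y = x - torch.linalg.solve(J_x, f_x)         # Newton step
        x = x - torch.linalg.solve(J_x, f_x + F(y))  # correction reusing the frozen Jacobian
        print(float(F(x).norm()))                    # residual norm shrinks toward 0

Each pass reuses a single Jacobian factorization for two solves, which is the "2f 1J 2 solves" trade-off noted in the comment above two_point_newton.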