torchzero-0.3.10-py3-none-any.whl → torchzero-0.3.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/restarts/__init__.py:
@@ -0,0 +1,7 @@
+ from .restars import (
+     BirginMartinezRestart,
+     PowellRestart,
+     RestartEvery,
+     RestartOnStuck,
+     RestartStrategyBase,
+ )
torchzero/modules/restarts/restars.py:
@@ -0,0 +1,252 @@
+ from abc import ABC, abstractmethod
+ from functools import partial
+ from typing import final, Literal, cast
+
+ import torch
+
+ from ...core import Chainable, Module, Var
+ from ...utils import TensorList
+ from ..termination import TerminationCriteriaBase
+
+ def _reset_except_self(optimizer, var, self: Module):
+     for m in optimizer.unrolled_modules: m.reset()
+
+ class RestartStrategyBase(Module, ABC):
+     """Base class for restart strategies.
+
+     On each ``update``/``step`` this checks the reset condition and, if it is satisfied,
+     resets the modules before updating or stepping.
+     """
+     def __init__(self, defaults: dict | None = None, modules: Chainable | None = None):
+         if defaults is None: defaults = {}
+         super().__init__(defaults)
+         if modules is not None:
+             self.set_child('modules', modules)
+
+     @abstractmethod
+     def should_reset(self, var: Var) -> bool:
+         """returns whether a reset should occur"""
+
+     def _reset_on_condition(self, var):
+         modules = self.children.get('modules', None)
+
+         if self.should_reset(var):
+             if modules is None:
+                 var.post_step_hooks.append(partial(_reset_except_self, self=self))
+             else:
+                 modules.reset()
+
+         return modules
+
+     @final
+     def update(self, var):
+         modules = self._reset_on_condition(var)
+         if modules is not None:
+             modules.update(var)
+
+     @final
+     def apply(self, var):
+         # don't check here because it was checked in `update`
+         modules = self.children.get('modules', None)
+         if modules is None: return var
+         return modules.apply(var.clone(clone_update=False))
+
+     @final
+     def step(self, var):
+         modules = self._reset_on_condition(var)
+         if modules is None: return var
+         return modules.step(var.clone(clone_update=False))
+
+
+
+ class RestartOnStuck(RestartStrategyBase):
+     """Resets the state when the update (difference in parameters) is close to zero for multiple steps in a row.
+
+     Args:
+         modules (Chainable | None):
+             modules to reset. If None, resets all modules.
+         tol (float, optional):
+             a step is considered failed when the maximum absolute parameter difference is smaller than this. Defaults to 1e-10.
+         n_tol (int, optional):
+             number of consecutive failed steps required to trigger a reset. Defaults to 4.
+
+     """
+     def __init__(self, modules: Chainable | None, tol: float = 1e-10, n_tol: int = 4):
+         defaults = dict(tol=tol, n_tol=n_tol)
+         super().__init__(defaults, modules)
+
+     @torch.no_grad
+     def should_reset(self, var):
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         params = TensorList(var.params)
+         tol = self.defaults['tol']
+         n_tol = self.defaults['n_tol']
+         n_bad = self.global_state.get('n_bad', 0)
+
+         # calculate difference in parameters
+         prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+         update = params - prev_params
+         prev_params.copy_(params)
+
+         # if the update is too small the step is considered bad, otherwise n_bad is reset to 0
+         if step > 0:
+             if update.abs().global_max() <= tol:
+                 n_bad += 1
+
+             else:
+                 n_bad = 0
+
+         self.global_state['n_bad'] = n_bad
+
+         # no progress, reset
+         if n_bad >= n_tol:
+             self.global_state.clear()
+             return True
+
+         return False
+
+
+ class RestartEvery(RestartStrategyBase):
+     """Resets the state every n steps
+
+     Args:
+         modules (Chainable | None):
+             modules to reset. If None, resets all modules.
+         steps (int | Literal["ndim"]):
+             number of steps between resets. "ndim" to use number of parameters.
+     """
+     def __init__(self, modules: Chainable | None, steps: int | Literal['ndim']):
+         defaults = dict(steps=steps)
+         super().__init__(defaults, modules)
+
+     def should_reset(self, var):
+         step = self.global_state.get('step', 0) + 1
+         self.global_state['step'] = step
+
+         n = self.defaults['steps']
+         if isinstance(n, str): n = sum(p.numel() for p in var.params if p.requires_grad)
+
+         # reset every n steps
+         if step % n == 0:
+             self.global_state.clear()
+             return True
+
+         return False
+
+ class RestartOnTerminationCriteria(RestartStrategyBase):
+     def __init__(self, modules: Chainable | None, criteria: "TerminationCriteriaBase"):
+         super().__init__(None, modules)
+         self.set_child('criteria', criteria)
+
+     def should_reset(self, var):
+         criteria = cast(TerminationCriteriaBase, self.children['criteria'])
+         return criteria.should_terminate(var)
+
+ class PowellRestart(RestartStrategyBase):
+     """Powell's two restart criteria for conjugate gradient methods.
+
+     The restart clears all states of ``modules``.
+
+     Args:
+         modules (Chainable | None):
+             modules to reset. If None, resets all modules.
+         cond1 (float | None, optional):
+             criterion that checks for nonconjugacy of the search directions.
+             Restart is performed whenever |g_k^T g_{k+1}| >= cond1*||g_{k+1}||^2.
+             The default condition value of 0.2 is suggested by Powell. Can be None to disable that criterion.
+         cond2 (float | None, optional):
+             criterion that checks if the direction is not effectively downhill.
+             Restart is performed if -1.2||g||^2 < d^T g < -0.8||g||^2.
+             Defaults to 0.2. Can be None to disable that criterion.
+
+     Reference:
+         Powell, Michael James David. "Restart procedures for the conjugate gradient method." Mathematical Programming 12.1 (1977): 241-254.
+     """
+     def __init__(self, modules: Chainable | None, cond1:float | None = 0.2, cond2:float | None = 0.2):
+         defaults=dict(cond1=cond1, cond2=cond2)
+         super().__init__(defaults, modules)
+
+     def should_reset(self, var):
+         g = TensorList(var.get_grad())
+         cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
+
+         # -------------------------------- initialize -------------------------------- #
+         if 'initialized' not in self.global_state:
+             self.global_state['initialized'] = 0
+             g_prev = self.get_state(var.params, 'g_prev', init=g)
+             return False
+
+         g_g = g.dot(g)
+
+         reset = False
+         # ------------------------------- 1st condition ------------------------------ #
+         if cond1 is not None:
+             g_prev = self.get_state(var.params, 'g_prev', must_exist=True, cls=TensorList)
+             g_g_prev = g_prev.dot(g)
+
+             if g_g_prev.abs() >= cond1 * g_g:
+                 reset = True
+
+         # ------------------------------- 2nd condition ------------------------------ #
+         if (cond2 is not None) and (not reset):
+             d_g = TensorList(var.get_update()).dot(g)
+             if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
+                 reset = True
+
+         # ------------------------------ clear on reset ------------------------------ #
+         if reset:
+             self.global_state.clear()
+             self.clear_state_keys('g_prev')
+             return True
+
+         return False
+
+ # this requires the direction from the inner module,
+ # so both this and the inner module have to be a single module
+ class BirginMartinezRestart(Module):
+     """The restart criterion for conjugate gradient methods designed by Birgin and Martinez.
+
+     This criterion restarts when the angle between d_{k+1} and -g_{k+1} is not acute enough.
+
+     The restart clears all states of ``module``.
+
+     Args:
+         module (Module):
+             module to restart, should be a conjugate gradient or possibly a quasi-newton method.
+         cond (float, optional):
+             Restart is performed whenever d^T g > -cond*||d||*||g||.
+             The default condition value of 1e-3 is suggested by Birgin and Martinez.
+
+     Reference:
+         Birgin, Ernesto G., and José Mario Martínez. "A spectral conjugate gradient method for unconstrained optimization." Applied Mathematics & Optimization 43.2 (2001): 117-128.
+     """
+     def __init__(self, module: Module, cond:float = 1e-3):
+         defaults=dict(cond=cond)
+         super().__init__(defaults)
+
+         self.set_child("module", module)
+
+     def update(self, var):
+         module = self.children['module']
+         module.update(var)
+
+     def apply(self, var):
+         module = self.children['module']
+         var = module.apply(var.clone(clone_update=False))
+
+         cond = self.defaults['cond']
+         g = TensorList(var.get_grad())
+         d = TensorList(var.get_update())
+         d_g = d.dot(g)
+         d_norm = d.global_vector_norm()
+         g_norm = g.global_vector_norm()
+
+         # d in our case points in the same direction as g, so it gets a minus sign
+         if -d_g > -cond * d_norm * g_norm:
+             module.reset()
+             var.update = g.clone()
+             return var
+
+         return var
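
The two Powell conditions above reduce to a simple test on flattened tensors. The sketch below is not part of the package; the helper name ``powell_should_restart`` and the toy tensors are illustrative only, with thresholds mirroring the 0.2 defaults of ``PowellRestart``:

import torch

def powell_should_restart(g, g_prev, d, cond1=0.2, cond2=0.2):
    """Hypothetical standalone version of PowellRestart.should_reset above,
    acting on a flattened gradient g, previous gradient g_prev and direction d."""
    g_g = g.dot(g)
    # condition 1: loss of conjugacy, |g_prev . g| >= cond1 * ||g||^2
    if g_prev.dot(g).abs() >= cond1 * g_g:
        return True
    # condition 2: d . g falls inside ((-1 - cond2) ||g||^2, (-1 + cond2) ||g||^2)
    d_g = d.dot(g)
    if (-1 - cond2) * g_g < d_g < (-1 + cond2) * g_g:
        return True
    return False

# toy check on random vectors
g_prev, g, d = torch.randn(3, 10).unbind(0)
print(powell_should_restart(g, g_prev, d))
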
torchzero/modules/second_order/__init__.py:
@@ -1,3 +1,4 @@
- from .newton import Newton
- from .newton_cg import NewtonCG
+ from .newton import Newton, InverseFreeNewton
+ from .newton_cg import NewtonCG, NewtonCGSteihaug
  from .nystrom import NystromSketchAndSolve, NystromPCG
+ from .multipoint import SixthOrder3P, SixthOrder5P, TwoPointNewton, SixthOrder3PM2
torchzero/modules/second_order/multipoint.py:
@@ -0,0 +1,238 @@
+ from collections.abc import Callable
+ from contextlib import nullcontext
+ from abc import ABC, abstractmethod
+ import numpy as np
+ import torch
+
+ from ...core import Chainable, Module, apply_transform, Var
+ from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
+ from ...utils.derivatives import (
+     flatten_jacobian,
+     jacobian_wrt,
+ )
+
+ class HigherOrderMethodBase(Module, ABC):
+     def __init__(self, defaults: dict | None = None, vectorize: bool = True):
+         self._vectorize = vectorize
+         super().__init__(defaults)
+
+     @abstractmethod
+     def one_iteration(
+         self,
+         x: torch.Tensor,
+         evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
+         var: Var,
+     ) -> torch.Tensor:
+         """"""
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         x0 = params.clone()
+         closure = var.closure
+         if closure is None: raise RuntimeError('MultipointNewton requires closure')
+         vectorize = self._vectorize
+
+         def evaluate(x, order) -> tuple[torch.Tensor, ...]:
+             """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
+             params.from_vec_(x)
+
+             if order == 0:
+                 loss = closure(False)
+                 params.copy_(x0)
+                 return (loss, )
+
+             if order == 1:
+                 with torch.enable_grad():
+                     loss = closure()
+                     grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                 params.copy_(x0)
+                 return loss, torch.cat([g.ravel() for g in grad])
+
+             with torch.enable_grad():
+                 loss = var.loss = var.loss_approx = closure(False)
+
+                 g_list = torch.autograd.grad(loss, params, create_graph=True)
+                 var.grad = list(g_list)
+
+                 g = torch.cat([t.ravel() for t in g_list])
+                 n = g.numel()
+                 ret = [loss, g]
+                 T = g # current derivatives tensor
+
+                 # get all derivatives up to `order`
+                 for o in range(2, order + 1):
+                     is_last = o == order
+                     T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
+                     with torch.no_grad() if is_last else nullcontext():
+                         # the shape is (ndim, ) * order
+                         T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
+                     ret.append(T)
+
+             params.copy_(x0)
+             return tuple(ret)
+
+         x = torch.cat([p.ravel() for p in params])
+         dir = self.one_iteration(x, evaluate, var)
+         var.update = vec_to_tensors(dir, var.params)
+         return var
+
+ def _inv(A: torch.Tensor, lstsq:bool) -> torch.Tensor:
+     if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable
+     A_inv, info = torch.linalg.inv_ex(A) # pylint:disable=not-callable
+     if info == 0: return A_inv
+     return torch.linalg.pinv(A) # pylint:disable=not-callable
+
+ def _solve(A: torch.Tensor, b: torch.Tensor, lstsq: bool) -> torch.Tensor:
+     if lstsq: return torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+     sol, info = torch.linalg.solve_ex(A, b) # pylint:disable=not-callable
+     if info == 0: return sol
+     return torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+
+ # 3f 2J 3 solves
+ def sixth_order_3p(x:torch.Tensor, f, f_j, lstsq:bool=False):
+     f_x, J_x = f_j(x)
+
+     y = x - _solve(J_x, f_x, lstsq=lstsq)
+     f_y, J_y = f_j(y)
+
+     z = y - _solve(J_y, f_y, lstsq=lstsq)
+     f_z = f(z)
+
+     return y - _solve(J_y, f_y+f_z, lstsq=lstsq)
+
+ class SixthOrder3P(HigherOrderMethodBase):
+     """Sixth-order iterative method.
+
+     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
+     """
+     def __init__(self, lstsq: bool=False, vectorize: bool = True):
+         defaults=dict(lstsq=lstsq)
+         super().__init__(defaults=defaults, vectorize=vectorize)
+
+     def one_iteration(self, x, evaluate, var):
+         settings = self.defaults
+         lstsq = settings['lstsq']
+         def f(x): return evaluate(x, 1)[1]
+         def f_j(x): return evaluate(x, 2)[1:]
+         x_star = sixth_order_3p(x, f, f_j, lstsq)
+         return x - x_star
+
+ # I don't think it works (I tested root finding with this and it goes all over the place)
+ # I double checked it multiple times
+ # def sixth_order_im1(x:torch.Tensor, f, f_j, lstsq:bool=False):
+ #     f_x, J_x = f_j(x)
+ #     J_x_inv = _inv(J_x, lstsq=lstsq)
+
+ #     y = x - J_x_inv @ f_x
+ #     f_y, J_y = f_j(y)
+
+ #     z = x - 2 * _solve(J_x + J_y, f_x, lstsq=lstsq)
+ #     f_z = f(z)
+
+ #     I = torch.eye(J_y.size(0), device=J_y.device, dtype=J_y.dtype)
+ #     term1 = (7/2)*I
+ #     term2 = 4 * (J_x_inv@J_y)
+ #     term3 = (3/2) * (J_x_inv @ (J_y@J_y))
+
+ #     return z - (term1 - term2 + term3) @ J_x_inv @ f_z
+
+ # class SixthOrderIM1(HigherOrderMethodBase):
+ #     """sixth-order iterative method https://www.mdpi.com/2504-3110/8/3/133
+
+ #     """
+ #     def __init__(self, lstsq: bool=False, vectorize: bool = True):
+ #         defaults=dict(lstsq=lstsq)
+ #         super().__init__(defaults=defaults, vectorize=vectorize)
+
+ #     def iteration(self, x, evaluate, var):
+ #         settings = self.defaults
+ #         lstsq = settings['lstsq']
+ #         def f(x): return evaluate(x, 1)[1]
+ #         def f_j(x): return evaluate(x, 2)[1:]
+ #         x_star = sixth_order_im1(x, f, f_j, lstsq)
+ #         return x - x_star
+
+ # 5f 5J 3 solves
+ def sixth_order_5p(x:torch.Tensor, f_j, lstsq:bool=False):
+     f_x, J_x = f_j(x)
+     y = x - _solve(J_x, f_x, lstsq=lstsq)
+
+     f_y, J_y = f_j(y)
+     f_xy2, J_xy2 = f_j((x + y) / 2)
+
+     A = J_x + 2*J_xy2 + J_y
+
+     z = y - 4*_solve(A, f_y, lstsq=lstsq)
+     f_z, J_z = f_j(z)
+
+     f_xz2, J_xz2 = f_j((x + z) / 2)
+     B = J_x + 2*J_xz2 + J_z
+
+     return z - 4*_solve(B, f_z, lstsq=lstsq)
+
+ class SixthOrder5P(HigherOrderMethodBase):
+     """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
+     def __init__(self, lstsq: bool=False, vectorize: bool = True):
+         defaults=dict(lstsq=lstsq)
+         super().__init__(defaults=defaults, vectorize=vectorize)
+
+     def one_iteration(self, x, evaluate, var):
+         settings = self.defaults
+         lstsq = settings['lstsq']
+         def f_j(x): return evaluate(x, 2)[1:]
+         x_star = sixth_order_5p(x, f_j, lstsq)
+         return x - x_star
+
+ # 2f 1J 2 solves
+ def two_point_newton(x: torch.Tensor, f, f_j, lstsq:bool=False):
+     """third order convergence"""
+     f_x, J_x = f_j(x)
+     y = x - _solve(J_x, f_x, lstsq=lstsq)
+     f_y = f(y)
+     return x - _solve(J_x, f_x + f_y, lstsq=lstsq)
+
+ class TwoPointNewton(HigherOrderMethodBase):
+     """Two-point Newton method with a frozen derivative and third-order convergence.
+
+     Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
+     def __init__(self, lstsq: bool=False, vectorize: bool = True):
+         defaults=dict(lstsq=lstsq)
+         super().__init__(defaults=defaults, vectorize=vectorize)
+
+     def one_iteration(self, x, evaluate, var):
+         settings = self.defaults
+         lstsq = settings['lstsq']
+         def f(x): return evaluate(x, 1)[1]
+         def f_j(x): return evaluate(x, 2)[1:]
+         x_star = two_point_newton(x, f, f_j, lstsq)
+         return x - x_star
+
+ # 3f 2J 1 inv
+ def sixth_order_3pm2(x:torch.Tensor, f, f_j, lstsq:bool=False):
+     f_x, J_x = f_j(x)
+     J_x_inv = _inv(J_x, lstsq=lstsq)
+     y = x - J_x_inv @ f_x
+     f_y, J_y = f_j(y)
+
+     I = torch.eye(x.numel(), dtype=x.dtype, device=x.device)
+     term = (2*I - J_x_inv @ J_y) @ J_x_inv
+     z = y - term @ f_y
+
+     return z - term @ f(z)
+
+
+ class SixthOrder3PM2(HigherOrderMethodBase):
+     """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
+     def __init__(self, lstsq: bool=False, vectorize: bool = True):
+         defaults=dict(lstsq=lstsq)
+         super().__init__(defaults=defaults, vectorize=vectorize)
+
+     def one_iteration(self, x, evaluate, var):
+         settings = self.defaults
+         lstsq = settings['lstsq']
+         def f_j(x): return evaluate(x, 2)[1:]
+         def f(x): return evaluate(x, 1)[1]
+         x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+         return x - x_star
+
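
The standalone helpers above (``sixth_order_3p``, ``two_point_newton``, ...) only assume two callables: ``f(x)`` returning a residual vector and ``f_j(x)`` returning the residual together with its Jacobian; in the modules these are the loss gradient and Hessian. Below is a minimal sketch of the same update as ``two_point_newton`` applied to a toy root-finding problem; the test system, the starting point, and the name ``two_point_newton_step`` are illustrative only, and ``torch.linalg.solve`` stands in for the ``_solve`` helper:

import torch
from torch.autograd.functional import jacobian

def F(x):
    # toy residual F: R^2 -> R^2 with a root at (1, 2)
    return torch.stack([x[0] ** 2 + x[1] - 3.0, x[0] + x[1] ** 2 - 5.0])

def f(x):
    return F(x)

def f_j(x):
    return F(x), jacobian(F, x)

def two_point_newton_step(x, f, f_j):
    # one Jacobian evaluation, two linear solves, as in two_point_newton above
    f_x, J_x = f_j(x)
    y = x - torch.linalg.solve(J_x, f_x)
    f_y = f(y)
    return x - torch.linalg.solve(J_x, f_x + f_y)

x = torch.tensor([1.2, 1.8])
for _ in range(5):
    x = two_point_newton_step(x, f, f_j)
print(x, F(x))  # x should be close to (1, 2) with residual near zero
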