torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/restarts/restars.py (new file)

@@ -0,0 +1,252 @@
+from abc import ABC, abstractmethod
+from functools import partial
+from typing import final, Literal, cast
+
+import torch
+
+from ...core import Chainable, Module, Var
+from ...utils import TensorList
+from ..termination import TerminationCriteriaBase
+
+def _reset_except_self(optimizer, var, self: Module):
+    for m in optimizer.unrolled_modules: m.reset()
+
+class RestartStrategyBase(Module, ABC):
+    """Base class for restart strategies.
+
+    On each ``update``/``step`` this checks the reset condition and, if it is satisfied,
+    resets the modules before updating or stepping.
+    """
+    def __init__(self, defaults: dict | None = None, modules: Chainable | None = None):
+        if defaults is None: defaults = {}
+        super().__init__(defaults)
+        if modules is not None:
+            self.set_child('modules', modules)
+
+    @abstractmethod
+    def should_reset(self, var: Var) -> bool:
+        """returns whether reset should occur"""
+
+    def _reset_on_condition(self, var):
+        modules = self.children.get('modules', None)
+
+        if self.should_reset(var):
+            if modules is None:
+                var.post_step_hooks.append(partial(_reset_except_self, self=self))
+            else:
+                modules.reset()
+
+        return modules
+
+    @final
+    def update(self, var):
+        modules = self._reset_on_condition(var)
+        if modules is not None:
+            modules.update(var)
+
+    @final
+    def apply(self, var):
+        # don't check here because it was checked in `update`
+        modules = self.children.get('modules', None)
+        if modules is None: return var
+        return modules.apply(var.clone(clone_update=False))
+
+    @final
+    def step(self, var):
+        modules = self._reset_on_condition(var)
+        if modules is None: return var
+        return modules.step(var.clone(clone_update=False))
+
+
+
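The base class above reduces a restart policy to a single ``should_reset`` override; wrapping the child modules and performing the reset is handled by ``RestartStrategyBase``. A minimal sketch of a custom strategy, using only the surface visible in this diff (``should_reset``, ``self.global_state``, ``self.defaults``, ``var.get_grad()``, ``TensorList.global_vector_norm()``) plus the ``torchzero.utils.TensorList`` import; the class itself is hypothetical and not part of the package:

    from torchzero.utils import TensorList

    class RestartOnGradientGrowth(RestartStrategyBase):
        """Hypothetical example: restart when the gradient norm grows by more than `factor`."""
        def __init__(self, modules=None, factor: float = 10.0):
            super().__init__(dict(factor=factor), modules)

        def should_reset(self, var):
            g_norm = TensorList(var.get_grad()).global_vector_norm()
            prev = self.global_state.get('prev_norm', None)
            self.global_state['prev_norm'] = g_norm
            if prev is None:
                return False
            return bool(g_norm > self.defaults['factor'] * prev)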
+class RestartOnStuck(RestartStrategyBase):
+    """Resets the state when update (difference in parameters) is close to zero for multiple steps in a row.
+
+    Args:
+        modules (Chainable | None):
+            modules to reset. If None, resets all modules.
+        tol (float, optional):
+            step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to 1e-10.
+        n_tol (int, optional):
+            number of consecutive failed steps required to trigger a reset. Defaults to 4.
+
+    """
+    def __init__(self, modules: Chainable | None, tol: float = 1e-10, n_tol: int = 4):
+        defaults = dict(tol=tol, n_tol=n_tol)
+        super().__init__(defaults, modules)
+
+    @torch.no_grad
+    def should_reset(self, var):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        params = TensorList(var.params)
+        tol = self.defaults['tol']
+        n_tol = self.defaults['n_tol']
+        n_bad = self.global_state.get('n_bad', 0)
+
+        # calculate difference in parameters
+        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+        update = params - prev_params
+        prev_params.copy_(params)
+
+        # if update is too small, it is considered bad, otherwise n_bad is reset to 0
+        if step > 0:
+            if update.abs().global_max() <= tol:
+                n_bad += 1
+
+            else:
+                n_bad = 0
+
+        self.global_state['n_bad'] = n_bad
+
+        # no progress, reset
+        if n_bad >= n_tol:
+            self.global_state.clear()
+            return True
+
+        return False
+
+
+class RestartEvery(RestartStrategyBase):
+    """Resets the state every n steps
+
+    Args:
+        modules (Chainable | None):
+            modules to reset. If None, resets all modules.
+        steps (int | Literal["ndim"]):
+            number of steps between resets. "ndim" to use number of parameters.
+    """
+    def __init__(self, modules: Chainable | None, steps: int | Literal['ndim']):
+        defaults = dict(steps=steps)
+        super().__init__(defaults, modules)
+
+    def should_reset(self, var):
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step
+
+        n = self.defaults['steps']
+        if isinstance(n, str): n = sum(p.numel() for p in var.params if p.requires_grad)
+
+        # reset every n steps
+        if step % n == 0:
+            self.global_state.clear()
+            return True
+
+        return False
+
+class RestartOnTerminationCriteria(RestartStrategyBase):
+    def __init__(self, modules: Chainable | None, criteria: "TerminationCriteriaBase"):
+        super().__init__(None, modules)
+        self.set_child('criteria', criteria)
+
+    def should_reset(self, var):
+        criteria = cast(TerminationCriteriaBase, self.children['criteria'])
+        return criteria.should_terminate(var)
+
+class PowellRestart(RestartStrategyBase):
+    """Powell's two restarting criteria for conjugate gradient methods.
+
+    The restart clears all states of ``modules``.
+
+    Args:
+        modules (Chainable | None):
+            modules to reset. If None, resets all modules.
+        cond1 (float | None, optional):
+            criterion that checks for nonconjugacy of the search directions.
+            Restart is performed whenever |g_k^T g_{k+1}| >= cond1*||g_{k+1}||^2.
+            The default condition value of 0.2 is suggested by Powell. Can be None to disable that criterion.
+        cond2 (float | None, optional):
+            criterion that checks if direction is not effectively downhill.
+            Restart is performed if -1.2||g||^2 < d^Tg < -0.8||g||^2.
+            Defaults to 0.2. Can be None to disable that criterion.
+
+    Reference:
+        Powell, Michael James David. "Restart procedures for the conjugate gradient method." Mathematical programming 12.1 (1977): 241-254.
+    """
+    def __init__(self, modules: Chainable | None, cond1: float | None = 0.2, cond2: float | None = 0.2):
+        defaults = dict(cond1=cond1, cond2=cond2)
+        super().__init__(defaults, modules)
+
+    def should_reset(self, var):
+        g = TensorList(var.get_grad())
+        cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
+
+        # -------------------------------- initialize -------------------------------- #
+        if 'initialized' not in self.global_state:
+            self.global_state['initialized'] = 0
+            g_prev = self.get_state(var.params, 'g_prev', init=g)
+            return False
+
+        g_g = g.dot(g)
+
+        reset = False
+        # ------------------------------- 1st condition ------------------------------ #
+        if cond1 is not None:
+            g_prev = self.get_state(var.params, 'g_prev', must_exist=True, cls=TensorList)
+            g_g_prev = g_prev.dot(g)
+
+            if g_g_prev.abs() >= cond1 * g_g:
+                reset = True
+
+        # ------------------------------- 2nd condition ------------------------------ #
+        if (cond2 is not None) and (not reset):
+            d_g = TensorList(var.get_update()).dot(g)
+            if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
+                reset = True
+
+        # ------------------------------ clear on reset ------------------------------ #
+        if reset:
+            self.global_state.clear()
+            self.clear_state_keys('g_prev')
+            return True
+
+        return False
+
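On flattened tensors the two Powell tests reduce to a couple of dot products. A self-contained restatement in plain torch, using only the formulas from the docstring above and none of the module machinery:

    import torch

    def powell_should_restart(g, g_prev, d, cond1=0.2, cond2=0.2):
        # g, g_prev, d: flattened current gradient, previous gradient and search direction
        g_g = g.dot(g)
        # condition 1: successive gradients are far from orthogonal, so conjugacy is lost
        if cond1 is not None and g_prev.dot(g).abs() >= cond1 * g_g:
            return True
        # condition 2: d^T g falls inside (-(1 + cond2)||g||^2, (-1 + cond2)||g||^2),
        # mirroring the check in the module above
        if cond2 is not None:
            d_g = d.dot(g)
            if (-1 - cond2) * g_g < d_g < (-1 + cond2) * g_g:
                return True
        return False

    g      = torch.tensor([0.5, -1.0, 0.25])
    g_prev = torch.tensor([0.4, -0.9, 0.30])
    d      = -g  # steepest descent direction satisfies d^T g = -||g||^2
    print(powell_should_restart(g, g_prev, d))  # condition 1 fires: the gradients are nearly parallel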
+# this requires direction from inner module,
+# so both this and inner have to be a single module
+class BirginMartinezRestart(Module):
+    """The restart criterion for conjugate gradient methods designed by Birgin and Martinez.
+
+    This criterion restarts when the angle between d_{k+1} and -g_{k+1} is not acute enough.
+
+    The restart clears all states of ``module``.
+
+    Args:
+        module (Module):
+            module to restart, should be a conjugate gradient or possibly a quasi-newton method.
+        cond (float, optional):
+            Restart is performed whenever d^Tg > -cond*||d||*||g||.
+            The default condition value of 1e-3 is suggested by Birgin and Martinez.
+
+    Reference:
+        Birgin, Ernesto G., and José Mario Martínez. "A spectral conjugate gradient method for unconstrained optimization." Applied Mathematics & Optimization 43.2 (2001): 117-128.
+    """
+    def __init__(self, module: Module, cond: float = 1e-3):
+        defaults = dict(cond=cond)
+        super().__init__(defaults)
+
+        self.set_child("module", module)
+
+    def update(self, var):
+        module = self.children['module']
+        module.update(var)
+
+    def apply(self, var):
+        module = self.children['module']
+        var = module.apply(var.clone(clone_update=False))
+
+        cond = self.defaults['cond']
+        g = TensorList(var.get_grad())
+        d = TensorList(var.get_update())
+        d_g = d.dot(g)
+        d_norm = d.global_vector_norm()
+        g_norm = g.global_vector_norm()
+
+        # d in our case is same direction as g so it has a minus sign
+        if -d_g > -cond * d_norm * g_norm:
+            module.reset()
+            var.update = g.clone()
+            return var
+
+        return var
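For context, a usage sketch of these restart modules. Names outside this diff are assumptions: the top-level ``tz.Modular`` wrapper, the ``tz.m`` namespace and an ``LBFGS`` module are inferred from the package's general layout, not from this diff; the ``closure(backward=True)`` convention is inferred from the ``closure(False)`` calls visible above.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    # hypothetical composition: restart an L-BFGS-like inner module whenever it gets stuck
    opt = tz.Modular(
        model.parameters(),
        tz.m.RestartOnStuck(tz.m.LBFGS(), tol=1e-10, n_tol=4),
    )

    def closure(backward=True):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(100):
        opt.step(closure)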
torchzero/modules/second_order/__init__.py

@@ -1,3 +1,4 @@
 from .newton import Newton, InverseFreeNewton
-from .newton_cg import NewtonCG,
+from .newton_cg import NewtonCG, NewtonCGSteihaug
 from .nystrom import NystromSketchAndSolve, NystromPCG
+from .multipoint import SixthOrder3P, SixthOrder5P, TwoPointNewton, SixthOrder3PM2
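After this change the new solvers are importable from the subpackage (module path taken from the file list above):

    from torchzero.modules.second_order import (
        NewtonCGSteihaug,                                   # now exported alongside NewtonCG
        SixthOrder3P, SixthOrder5P, TwoPointNewton, SixthOrder3PM2,
    )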
torchzero/modules/second_order/multipoint.py (new file)

@@ -0,0 +1,238 @@
+from collections.abc import Callable
+from contextlib import nullcontext
+from abc import ABC, abstractmethod
+import numpy as np
+import torch
+
+from ...core import Chainable, Module, apply_transform, Var
+from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
+from ...utils.derivatives import (
+    flatten_jacobian,
+    jacobian_wrt,
+)
+
+class HigherOrderMethodBase(Module, ABC):
+    def __init__(self, defaults: dict | None = None, vectorize: bool = True):
+        self._vectorize = vectorize
+        super().__init__(defaults)
+
+    @abstractmethod
+    def one_iteration(
+        self,
+        x: torch.Tensor,
+        evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
+        var: Var,
+    ) -> torch.Tensor:
+        """"""
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        x0 = params.clone()
+        closure = var.closure
+        if closure is None: raise RuntimeError('MultipointNewton requires closure')
+        vectorize = self._vectorize
+
+        def evaluate(x, order) -> tuple[torch.Tensor, ...]:
+            """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
+            params.from_vec_(x)
+
+            if order == 0:
+                loss = closure(False)
+                params.copy_(x0)
+                return (loss, )
+
+            if order == 1:
+                with torch.enable_grad():
+                    loss = closure()
+                grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                params.copy_(x0)
+                return loss, torch.cat([g.ravel() for g in grad])
+
+            with torch.enable_grad():
+                loss = var.loss = var.loss_approx = closure(False)
+
+                g_list = torch.autograd.grad(loss, params, create_graph=True)
+                var.grad = list(g_list)
+
+                g = torch.cat([t.ravel() for t in g_list])
+                n = g.numel()
+                ret = [loss, g]
+                T = g # current derivatives tensor
+
+                # get all derivatives up to order
+                for o in range(2, order + 1):
+                    is_last = o == order
+                    T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
+                    with torch.no_grad() if is_last else nullcontext():
+                        # the shape is (ndim, ) * order
+                        T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
+                        ret.append(T)
+
+            params.copy_(x0)
+            return tuple(ret)
+
+        x = torch.cat([p.ravel() for p in params])
+        dir = self.one_iteration(x, evaluate, var)
+        var.update = vec_to_tensors(dir, var.params)
+        return var
+
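The ``evaluate(x, order)`` callback above is the whole contract a subclass programs against: it returns ``(loss,)``, ``(loss, grad)`` or ``(loss, grad, hessian)`` at an arbitrary flat point ``x`` and restores the original parameters afterwards. A minimal, hypothetical sketch of a subclass, a plain Newton step written against that contract (shown only to illustrate the interface; the concrete methods follow below in the diff):

    import torch

    class PlainNewton(HigherOrderMethodBase):
        def one_iteration(self, x, evaluate, var):
            _, g, H = evaluate(x, 2)          # loss, flat gradient, Hessian at x
            # one_iteration returns the update vector (x - x_new) written into var.update,
            # matching the `return x - x_star` pattern of the classes below
            return torch.linalg.solve(H, g)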
+def _inv(A: torch.Tensor, lstsq: bool) -> torch.Tensor:
+    if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable
+    A_inv, info = torch.linalg.inv_ex(A) # pylint:disable=not-callable
+    if info == 0: return A_inv
+    return torch.linalg.pinv(A) # pylint:disable=not-callable
+
+def _solve(A: torch.Tensor, b: torch.Tensor, lstsq: bool) -> torch.Tensor:
+    if lstsq: return torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+    sol, info = torch.linalg.solve_ex(A, b) # pylint:disable=not-callable
+    if info == 0: return sol
+    return torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+
+# 3f 2J 3 solves
+def sixth_order_3p(x: torch.Tensor, f, f_j, lstsq: bool = False):
+    f_x, J_x = f_j(x)
+
+    y = x - _solve(J_x, f_x, lstsq=lstsq)
+    f_y, J_y = f_j(y)
+
+    z = y - _solve(J_y, f_y, lstsq=lstsq)
+    f_z = f(z)
+
+    return y - _solve(J_y, f_y+f_z, lstsq=lstsq)
+
+class SixthOrder3P(HigherOrderMethodBase):
+    """Sixth-order iterative method.
+
+    Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
+    """
+    def __init__(self, lstsq: bool = False, vectorize: bool = True):
+        defaults = dict(lstsq=lstsq)
+        super().__init__(defaults=defaults, vectorize=vectorize)
+
+    def one_iteration(self, x, evaluate, var):
+        settings = self.defaults
+        lstsq = settings['lstsq']
+        def f(x): return evaluate(x, 1)[1]
+        def f_j(x): return evaluate(x, 2)[1:]
+        x_star = sixth_order_3p(x, f, f_j, lstsq)
+        return x - x_star
+
+# I don't think it works (I tested root finding with this and it goes all over the place)
+# I double checked it multiple times
+# def sixth_order_im1(x:torch.Tensor, f, f_j, lstsq:bool=False):
+#     f_x, J_x = f_j(x)
+#     J_x_inv = _inv(J_x, lstsq=lstsq)
+
+#     y = x - J_x_inv @ f_x
+#     f_y, J_y = f_j(y)
+
+#     z = x - 2 * _solve(J_x + J_y, f_x, lstsq=lstsq)
+#     f_z = f(z)
+
+#     I = torch.eye(J_y.size(0), device=J_y.device, dtype=J_y.dtype)
+#     term1 = (7/2)*I
+#     term2 = 4 * (J_x_inv@J_y)
+#     term3 = (3/2) * (J_x_inv @ (J_y@J_y))
+
+#     return z - (term1 - term2 + term3) @ J_x_inv @ f_z
+
+# class SixthOrderIM1(HigherOrderMethodBase):
+#     """sixth-order iterative method https://www.mdpi.com/2504-3110/8/3/133
+
+#     """
+#     def __init__(self, lstsq: bool=False, vectorize: bool = True):
+#         defaults=dict(lstsq=lstsq)
+#         super().__init__(defaults=defaults, vectorize=vectorize)
+
+#     def iteration(self, x, evaluate, var):
+#         settings = self.defaults
+#         lstsq = settings['lstsq']
+#         def f(x): return evaluate(x, 1)[1]
+#         def f_j(x): return evaluate(x, 2)[1:]
+#         x_star = sixth_order_im1(x, f, f_j, lstsq)
+#         return x - x_star
+
+# 5f 5J 3 solves
+def sixth_order_5p(x: torch.Tensor, f_j, lstsq: bool = False):
+    f_x, J_x = f_j(x)
+    y = x - _solve(J_x, f_x, lstsq=lstsq)
+
+    f_y, J_y = f_j(y)
+    f_xy2, J_xy2 = f_j((x + y) / 2)
+
+    A = J_x + 2*J_xy2 + J_y
+
+    z = y - 4*_solve(A, f_y, lstsq=lstsq)
+    f_z, J_z = f_j(z)
+
+    f_xz2, J_xz2 = f_j((x + z) / 2)
+    B = J_x + 2*J_xz2 + J_z
+
+    return z - 4*_solve(B, f_z, lstsq=lstsq)
+
+class SixthOrder5P(HigherOrderMethodBase):
+    """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
+    def __init__(self, lstsq: bool = False, vectorize: bool = True):
+        defaults = dict(lstsq=lstsq)
+        super().__init__(defaults=defaults, vectorize=vectorize)
+
+    def one_iteration(self, x, evaluate, var):
+        settings = self.defaults
+        lstsq = settings['lstsq']
+        def f_j(x): return evaluate(x, 2)[1:]
+        x_star = sixth_order_5p(x, f_j, lstsq)
+        return x - x_star
+
+# 2f 1J 2 solves
+def two_point_newton(x: torch.Tensor, f, f_j, lstsq: bool = False):
+    """third order convergence"""
+    f_x, J_x = f_j(x)
+    y = x - _solve(J_x, f_x, lstsq=lstsq)
+    f_y = f(y)
+    return x - _solve(J_x, f_x + f_y, lstsq=lstsq)
+
+class TwoPointNewton(HigherOrderMethodBase):
+    """two-point Newton method with frozen derivative and third order convergence.
+
+    Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
+    def __init__(self, lstsq: bool = False, vectorize: bool = True):
+        defaults = dict(lstsq=lstsq)
+        super().__init__(defaults=defaults, vectorize=vectorize)
+
+    def one_iteration(self, x, evaluate, var):
+        settings = self.defaults
+        lstsq = settings['lstsq']
+        def f(x): return evaluate(x, 1)[1]
+        def f_j(x): return evaluate(x, 2)[1:]
+        x_star = two_point_newton(x, f, f_j, lstsq)
+        return x - x_star
+
+# 3f 2J 1 inv
+def sixth_order_3pm2(x: torch.Tensor, f, f_j, lstsq: bool = False):
+    f_x, J_x = f_j(x)
+    J_x_inv = _inv(J_x, lstsq=lstsq)
+    y = x - J_x_inv @ f_x
+    f_y, J_y = f_j(y)
+
+    I = torch.eye(x.numel(), dtype=x.dtype, device=x.device)
+    term = (2*I - J_x_inv @ J_y) @ J_x_inv
+    z = y - term @ f_y
+
+    return z - term @ f(z)
+
+
+class SixthOrder3PM2(HigherOrderMethodBase):
+    """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
+    def __init__(self, lstsq: bool = False, vectorize: bool = True):
+        defaults = dict(lstsq=lstsq)
+        super().__init__(defaults=defaults, vectorize=vectorize)
+
+    def one_iteration(self, x, evaluate, var):
+        settings = self.defaults
+        lstsq = settings['lstsq']
+        def f_j(x): return evaluate(x, 2)[1:]
+        def f(x): return evaluate(x, 1)[1]
+        x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+        return x - x_star
+
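As a stand-alone illustration of the simplest of these iterations, ``two_point_newton``, the same predictor/corrector update applied to a toy 2-D root-finding problem in plain torch; the problem and helpers are made up for the example, and none of the module machinery is used:

    import torch

    def F(x):   # toy system with a root at (1, 1)
        return torch.stack([x[0]**2 + x[1] - 2, x[0] + x[1]**2 - 2])

    def J(x):   # analytic Jacobian of F
        one = torch.ones(())
        return torch.stack([torch.stack([2 * x[0], one]),
                            torch.stack([one, 2 * x[1]])])

    x = torch.tensor([2.0, 0.5])
    for _ in range(4):
        f_x, J_x = F(x), J(x)
        y = x - torch.linalg.solve(J_x, f_x)         # Newton predictor
        x = x - torch.linalg.solve(J_x, f_x + F(y))  # corrector with the frozen Jacobian
    print(x)  # rapidly approaches tensor([1., 1.])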