torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in the public registry.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/restarts/restars.py:

@@ -4,12 +4,14 @@ from typing import final, Literal, cast
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Objective
 from ...utils import TensorList
 from ..termination import TerminationCriteriaBase
 
-def _reset_except_self(
-    for m in
+def _reset_except_self(objective, modules, self: Module):
+    for m in modules:
+        if m is not self:
+            m.reset()
 
 class RestartStrategyBase(Module, ABC):
     """Base class for restart strategies.

@@ -24,38 +26,38 @@ class RestartStrategyBase(Module, ABC):
         self.set_child('modules', modules)
 
     @abstractmethod
-    def should_reset(self,
+    def should_reset(self, objective: Objective) -> bool:
         """returns whether reset should occur"""
 
-    def _reset_on_condition(self,
+    def _reset_on_condition(self, objective: Objective):
         modules = self.children.get('modules', None)
 
-        if self.should_reset(
+        if self.should_reset(objective):
             if modules is None:
-
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
             else:
                 modules.reset()
 
         return modules
 
     @final
-    def update(self,
-        modules = self._reset_on_condition(
+    def update(self, objective):
+        modules = self._reset_on_condition(objective)
         if modules is not None:
-            modules.update(
+            modules.update(objective)
 
     @final
-    def apply(self,
+    def apply(self, objective):
         # don't check here because it was check in `update`
         modules = self.children.get('modules', None)
-        if modules is None: return
-        return modules.apply(
+        if modules is None: return objective
+        return modules.apply(objective.clone(clone_updates=False))
 
     @final
-    def step(self,
-        modules = self._reset_on_condition(
-        if modules is None: return
-        return modules.step(
+    def step(self, objective):
+        modules = self._reset_on_condition(objective)
+        if modules is None: return objective
+        return modules.step(objective.clone(clone_updates=False))
 
 
 

@@ -76,11 +78,11 @@ class RestartOnStuck(RestartStrategyBase):
         super().__init__(defaults, modules)
 
     @torch.no_grad
-    def should_reset(self,
+    def should_reset(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-        params = TensorList(
+        params = TensorList(objective.params)
         tol = self.defaults['tol']
         if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
         n_tol = self.defaults['n_tol']

@@ -122,12 +124,12 @@ class RestartEvery(RestartStrategyBase):
         defaults = dict(steps=steps)
         super().__init__(defaults, modules)
 
-    def should_reset(self,
+    def should_reset(self, objective):
         step = self.global_state.get('step', 0) + 1
         self.global_state['step'] = step
 
         n = self.defaults['steps']
-        if isinstance(n, str): n = sum(p.numel() for p in
+        if isinstance(n, str): n = sum(p.numel() for p in objective.params if p.requires_grad)
 
         # reset every n steps
         if step % n == 0:

@@ -141,9 +143,9 @@ class RestartOnTerminationCriteria(RestartStrategyBase):
         super().__init__(None, modules)
         self.set_child('criteria', criteria)
 
-    def should_reset(self,
+    def should_reset(self, objective):
         criteria = cast(TerminationCriteriaBase, self.children['criteria'])
-        return criteria.should_terminate(
+        return criteria.should_terminate(objective)
 
 class PowellRestart(RestartStrategyBase):
     """Powell's two restarting criterions for conjugate gradient methods.

@@ -169,14 +171,14 @@ class PowellRestart(RestartStrategyBase):
         defaults=dict(cond1=cond1, cond2=cond2)
         super().__init__(defaults, modules)
 
-    def should_reset(self,
-        g = TensorList(
+    def should_reset(self, objective):
+        g = TensorList(objective.get_grads())
         cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
 
         # -------------------------------- initialize -------------------------------- #
         if 'initialized' not in self.global_state:
             self.global_state['initialized'] = 0
-            g_prev = self.get_state(
+            g_prev = self.get_state(objective.params, 'g_prev', init=g)
             return False
 
         g_g = g.dot(g)

@@ -184,7 +186,7 @@ class PowellRestart(RestartStrategyBase):
         reset = False
         # ------------------------------- 1st condition ------------------------------ #
         if cond1 is not None:
-            g_prev = self.get_state(
+            g_prev = self.get_state(objective.params, 'g_prev', must_exist=True, cls=TensorList)
             g_g_prev = g_prev.dot(g)
 
             if g_g_prev.abs() >= cond1 * g_g:

@@ -192,7 +194,7 @@ class PowellRestart(RestartStrategyBase):
 
         # ------------------------------- 2nd condition ------------------------------ #
         if (cond2 is not None) and (not reset):
-            d_g = TensorList(
+            d_g = TensorList(objective.get_updates()).dot(g)
             if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                 reset = True
 

@@ -229,17 +231,17 @@ class BirginMartinezRestart(Module):
 
         self.set_child("module", module)
 
-    def update(self,
+    def update(self, objective):
         module = self.children['module']
-        module.update(
+        module.update(objective)
 
-    def apply(self,
+    def apply(self, objective):
         module = self.children['module']
-
+        objective = module.apply(objective.clone(clone_updates=False))
 
         cond = self.defaults['cond']
-        g = TensorList(
-        d = TensorList(
+        g = TensorList(objective.get_grads())
+        d = TensorList(objective.get_updates())
         d_g = d.dot(g)
         d_norm = d.global_vector_norm()
         g_norm = g.global_vector_norm()

@@ -247,7 +249,7 @@ class BirginMartinezRestart(Module):
         # d in our case is same direction as g so it has a minus sign
         if -d_g > -cond * d_norm * g_norm:
             module.reset()
-
-            return
+            objective.updates = g.clone()
+            return objective
 
-        return
+        return objective
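The hunks above capture the 0.4.x migration in which restart strategies receive an `Objective` instead of the removed `Var` (`torchzero/core/var.py` is deleted and `torchzero/core/objective.py` added in the file list). As a rough illustration, here is a minimal sketch of a custom restart criterion written against the new `should_reset(objective)` signature. Only `RestartStrategyBase`, `Objective`, `TensorList`, and the accessors visible in the diff are taken from the source; the import paths, class name, and the gradient-norm heuristic itself are assumptions.

```python
# Hypothetical sketch against the 0.4.x Objective-based API shown above.
# RestartStrategyBase, Objective, TensorList, objective.get_grads(),
# global_vector_norm() and self.global_state all appear in the diff;
# the import paths and the criterion are assumptions for illustration.
import torch

from torchzero.core import Objective
from torchzero.utils import TensorList
from torchzero.modules.restarts.restars import RestartStrategyBase


class RestartOnGradNormIncrease(RestartStrategyBase):
    """Reset the wrapped modules whenever the gradient norm grows."""

    def __init__(self, modules=None):
        super().__init__(None, modules)

    @torch.no_grad
    def should_reset(self, objective: Objective) -> bool:
        g_norm = TensorList(objective.get_grads()).global_vector_norm()
        prev = self.global_state.get('g_norm_prev', None)
        self.global_state['g_norm_prev'] = g_norm
        return prev is not None and bool(g_norm > prev)
```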
torchzero/modules/second_order/__init__.py:

@@ -1,7 +1,7 @@
 from .ifn import InverseFreeNewton
-from .inm import
+from .inm import ImprovedNewton
 from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
 from .newton import Newton
 from .newton_cg import NewtonCG, NewtonCGSteihaug
 from .nystrom import NystromPCG, NystromSketchAndSolve
-from .rsn import
+from .rsn import SubspaceNewton
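For downstream code the practical effect of this hunk is a rename of the classes exported by `torchzero.modules.second_order`. A minimal import sketch under the new names (assuming the subpackage is imported directly, as its `__init__.py` above suggests):

```python
# 0.4.1 export names from the __init__.py hunk above; the truncated old
# exports from .inm and .rsn (e.g. the INM class visible in a later hunk
# header) are no longer available.
from torchzero.modules.second_order import (
    ImprovedNewton,    # exported from .inm under a new name
    SubspaceNewton,    # exported from .rsn under a new name
    InverseFreeNewton,
    Newton,
    NewtonCG,
)
```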
torchzero/modules/second_order/ifn.py:

@@ -1,89 +1,58 @@
-import warnings
-from collections.abc import Callable
-from functools import partial
-from typing import Literal
-
 import torch
 
-from ...core import Chainable,
+from ...core import Chainable, Transform, HessianMethod
 from ...utils import TensorList, vec_to_tensors
-from ...
-from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+from ...linalg.linear_operator import DenseWithInverse
 
 
-class InverseFreeNewton(
+class InverseFreeNewton(Transform):
     """Inverse-free newton's method
 
-    .. note::
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating the hessian.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. warning::
-        this uses roughly O(N^2) memory.
-
     Reference
         [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
     """
     def __init__(
         self,
         update_freq: int = 1,
-        hessian_method:
-
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
         inner: Chainable | None = None,
     ):
-        defaults = dict(hessian_method=hessian_method,
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = dict(hessian_method=hessian_method, h=h)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def
-
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
-
-
+        _, _, H = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
 
-
-        loss, g_list, H = _get_loss_grad_and_hessian(
-            var, self.defaults['hessian_method'], self.defaults['vectorize']
-        )
-        self.global_state["H"] = H
+        self.global_state["H"] = H
 
-
-
-
-
+        # inverse free part
+        if 'Y' not in self.global_state:
+            num = H.T
+            denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
 
-
-
+            finfo = torch.finfo(H.dtype)
+            self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
 
-
-
-
-
-
+        else:
+            Y = self.global_state['Y']
+            I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+            I2 -= H @ Y
+            self.global_state['Y'] = Y @ I2
 
 
-    def
+    def apply_states(self, objective, states, settings):
         Y = self.global_state["Y"]
-
-
-
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
-
-        # ----------------------------------- solve ---------------------------------- #
-        var.update = vec_to_tensors(Y@g, params)
-
-        return var
+        g = torch.cat([t.ravel() for t in objective.get_updates()])
+        objective.updates = vec_to_tensors(Y@g, objective.params)
+        return objective
 
-    def get_H(self,
+    def get_H(self,objective=...):
         return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
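The rewritten `update_states` above maintains an explicit inverse estimate `Y`: it is seeded as `H.T / (||H||_1 * ||H||_inf)` and then refined with `Y <- Y (2I - H Y)`, the Newton-Schulz iteration, so the inverse Hessian is approximated without ever calling a solver. A small self-contained check of that recurrence in plain PyTorch (independent of torchzero; the SPD test matrix and iteration count are arbitrary choices for this sketch):

```python
# Standalone check of the inverse-free recurrence used in update_states:
# Y_0 = H^T / (||H||_1 * ||H||_inf),  Y_{k+1} = Y_k (2I - H Y_k)  ->  H^{-1}
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n, dtype=torch.float64)
H = A @ A.T + torch.eye(n, dtype=torch.float64)   # SPD stand-in for a Hessian

# same initialization as the diff: transpose scaled by the 1-norm * inf-norm
Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf')))
I = torch.eye(n, dtype=torch.float64)
for _ in range(50):
    Y = Y @ (2 * I - H @ Y)                        # Newton-Schulz step

print(torch.allclose(Y, torch.linalg.inv(H), atol=1e-8))   # True
```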
torchzero/modules/second_order/inm.py:

@@ -1,12 +1,11 @@
 from collections.abc import Callable
-from typing import Literal
 
 import torch
 
-from ...core import Chainable,
-from ...utils import TensorList,
-from ..
-from .newton import
+from ...core import Chainable, Transform, HessianMethod
+from ...utils import TensorList, vec_to_tensors_, unpack_states
+from ..opt_utils import safe_clip
+from .newton import _newton_update_state_, _newton_solve, _newton_get_H
 
 @torch.no_grad
 def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):

@@ -25,7 +24,7 @@ def _eigval_fn(J: torch.Tensor, fn) -> torch.Tensor:
     L, Q = torch.linalg.eigh(J) # pylint:disable=not-callable
     return (Q * L.unsqueeze(-2)) @ Q.mH
 
-class
+class ImprovedNewton(Transform):
     """Improved Newton's Method (INM).
 
     Reference:

@@ -35,71 +34,76 @@ class INM(Module):
     def __init__(
         self,
         damping: float = 0,
-
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
         update_freq: int = 1,
-
-
+        precompute_inverse: bool | None = None,
+        use_lstsq: bool = False,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
         inner: Chainable | None = None,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child("inner", inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner, )
 
     @torch.no_grad
-    def
-
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
-
-
-
-
+        _, f_list, J = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
+        if f_list is None: f_list = objective.get_grads()
 
-
-
+        f = torch.cat([t.ravel() for t in f_list])
+        J = _eigval_fn(J, fs["eigval_fn"])
 
-
-
-
+        x_list = TensorList(objective.params)
+        f_list = TensorList(objective.get_grads())
+        x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)
 
-
-
-
-
-
-            return
+        # initialize on 1st step, do Newton step
+        if "H" not in self.global_state:
+            x_prev.copy_(x_list)
+            f_prev.copy_(f_list)
+            P = J
 
-
+        # INM update
+        else:
             s_list = x_list - x_prev
             y_list = f_list - f_prev
             x_prev.copy_(x_list)
             f_prev.copy_(f_list)
 
-
+            P = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
 
+        # update state
+        precompute_inverse = fs["precompute_inverse"]
+        if precompute_inverse is None:
+            precompute_inverse = fs["__update_freq"] >= 10
 
-
-
-
-
-
-
-
-            inner=self.children.get("inner", None),
-            H_tfm=self.defaults["H_tfm"],
-            eigval_fn=None, # it is applied in `update`
-            use_lstsq=self.defaults["use_lstsq"],
+        _newton_update_state_(
+            H=P,
+            state = self.global_state,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = precompute_inverse,
+            use_lstsq = fs["use_lstsq"]
         )
 
-
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
+        fs = settings[0]
+
+        b = torch.cat([t.ravel() for t in updates])
+        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])
+
+        vec_to_tensors_(sol, updates)
+        return objective
 
-        return var
 
-    def get_H(self,
-        return
+    def get_H(self,objective=...):
+        return _newton_get_H(self.global_state)
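The `ImprovedNewton` constructor now takes its options as plain keyword arguments (captured via `locals()`), and `precompute_inverse=None` is resolved at update time from the transform's `update_freq` (the `fs["__update_freq"] >= 10` check above). A constructor sketch with the keyword names taken directly from the hunk; the values are just the diff's defaults, the trailing comments are my reading of them, and wiring the module into an optimizer is left to the library's modular interface:

```python
# Keyword names and defaults come from the __init__ signature in the hunk
# above; this only constructs the module and does not claim a full setup.
from torchzero.modules.second_order import ImprovedNewton

module = ImprovedNewton(
    damping=0.0,
    eigval_fn=None,            # optional transform applied to the eigenvalues
    update_freq=1,             # how often update_states recomputes the Hessian
    precompute_inverse=None,   # None -> enabled automatically when update_freq >= 10
    use_lstsq=False,           # fall back to least-squares solves if True
    hessian_method="batched_autograd",
    h=1e-3,                    # presumably a finite-difference step for non-autograd hessian_method values
)
```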
torchzero/modules/second_order/multipoint.py:

@@ -1,19 +1,17 @@
-from collections.abc import Callable
-from contextlib import nullcontext
 from abc import ABC, abstractmethod
+from collections.abc import Callable, Mapping
+from typing import Any
+
 import numpy as np
 import torch
 
-from ...core import Chainable,
-from ...utils import TensorList, vec_to_tensors
-
-    flatten_jacobian,
-    jacobian_wrt,
-)
+from ...core import Chainable, DerivativesMethod, Objective, Transform
+from ...utils import TensorList, vec_to_tensors
+
 
-class HigherOrderMethodBase(
-    def __init__(self, defaults: dict | None = None,
-        self.
+class HigherOrderMethodBase(Transform, ABC):
+    def __init__(self, defaults: dict | None = None, derivatives_method: DerivativesMethod = 'batched_autograd'):
+        self._derivatives_method: DerivativesMethod = derivatives_method
         super().__init__(defaults)
 
     @abstractmethod

@@ -21,61 +19,27 @@ class HigherOrderMethodBase(Module, ABC):
         self,
         x: torch.Tensor,
         evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
-
+        objective: Objective,
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """"""
 
     @torch.no_grad
-    def
-        params = TensorList(
-
-        closure =
+    def apply_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+
+        closure = objective.closure
         if closure is None: raise RuntimeError('MultipointNewton requires closure')
-
+        derivatives_method = self._derivatives_method
 
         def evaluate(x, order) -> tuple[torch.Tensor, ...]:
            """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
-
-
-            if order == 0:
-                loss = closure(False)
-                params.copy_(x0)
-                return (loss, )
-
-            if order == 1:
-                with torch.enable_grad():
-                    loss = closure()
-                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                params.copy_(x0)
-                return loss, torch.cat([g.ravel() for g in grad])
-
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-
-                g_list = torch.autograd.grad(loss, params, create_graph=True)
-                var.grad = list(g_list)
-
-                g = torch.cat([t.ravel() for t in g_list])
-                n = g.numel()
-                ret = [loss, g]
-                T = g # current derivatives tensor
-
-                # get all derivative up to order
-                for o in range(2, order + 1):
-                    is_last = o == order
-                    T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-                    with torch.no_grad() if is_last else nullcontext():
-                        # the shape is (ndim, ) * order
-                        T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                    ret.append(T)
-
-            params.copy_(x0)
-            return tuple(ret)
+            return objective.derivatives_at(x, order, method=derivatives_method)
 
         x = torch.cat([p.ravel() for p in params])
-        dir = self.one_iteration(x, evaluate,
-
-        return
+        dir = self.one_iteration(x, evaluate, objective, settings[0])
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
 
 def _inv(A: torch.Tensor, lstsq:bool) -> torch.Tensor:
     if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable

@@ -106,16 +70,15 @@ class SixthOrder3P(HigherOrderMethodBase):
 
     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
     """
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_3p(x, f, f_j, lstsq)
+        x_star = sixth_order_3p(x, f, f_j, setting['lstsq'])
         return x - x_star
 
 # I don't think it works (I tested root finding with this and it goes all over the place)

@@ -173,15 +136,14 @@ def sixth_order_5p(x:torch.Tensor, f_j, lstsq:bool=False):
 
 class SixthOrder5P(HigherOrderMethodBase):
     """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_5p(x, f_j, lstsq)
+        x_star = sixth_order_5p(x, f_j, setting['lstsq'])
         return x - x_star
 
 # 2f 1J 2 solves

@@ -196,16 +158,15 @@ class TwoPointNewton(HigherOrderMethodBase):
     """two-point Newton method with frozen derivative with third order convergence.
 
     Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = two_point_newton(x, f, f_j, lstsq)
+        x_star = two_point_newton(x, f, f_j, setting['lstsq'])
         return x - x_star
 
 #3f 2J 1inv

@@ -224,15 +185,14 @@ def sixth_order_3pm2(x:torch.Tensor, f, f_j, lstsq:bool=False):
 
 class SixthOrder3PM2(HigherOrderMethodBase):
     """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
         def f(x): return evaluate(x, 1)[1]
-        x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+        x_star = sixth_order_3pm2(x, f, f_j, setting['lstsq'])
         return x - x_star
 
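The refactored `HigherOrderMethodBase` now hands each subclass a flat parameter vector `x`, an `evaluate(x, order)` callback (order=1 returns `(loss, grad)`, order=2 adds the Hessian, per the docstring kept in the hunk), the `Objective`, and its settings mapping, and expects `one_iteration` to return the update direction `x - x_star`. A hypothetical subclass sketch in that style, a plain Newton step: the `one_iteration` signature, the `evaluate()` semantics, and the `_inv` helper come from the hunks above, while the class itself is made up for illustration.

```python
# Hypothetical subclass of the refactored HigherOrderMethodBase; only the
# signatures and helpers visible in the diff are assumed to exist.
import torch

from torchzero.modules.second_order.multipoint import HigherOrderMethodBase, _inv


class PlainNewton(HigherOrderMethodBase):
    """A single undamped Newton step phrased as a higher-order method."""

    def __init__(self, lstsq: bool = False, derivatives_method='batched_autograd'):
        super().__init__(defaults=dict(lstsq=lstsq), derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        _, g, H = evaluate(x, 2)                    # loss, flat gradient, dense Hessian
        x_star = x - _inv(H, setting['lstsq']) @ g  # Newton iterate
        return x - x_star                           # base class expects the update direction
```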