torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/sg2.py

@@ -1,6 +1,6 @@
 import torch
 
-from ...core import Module, Chainable,
+from ...core import Module, Chainable, step
 from ...utils import TensorList, vec_to_tensors
 from ..second_order.newton import _newton_step, _get_H
 
@@ -58,12 +58,12 @@ class SG2(Module):
         if inner is not None: self.set_child('inner', inner)
 
     @torch.no_grad
-    def update(self,
+    def update(self, objective):
         k = self.global_state.get('step', 0) + 1
         self.global_state["step"] = k
 
-        params = TensorList(
-        closure =
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None:
             raise RuntimeError("closure is required for SG2")
         generator = self.get_generator(params[0].device, self.defaults["seed"])
@@ -79,7 +79,7 @@ class SG2(Module):
 
         # one sided
         if self.defaults["one_sided"]:
-            g_0 = TensorList(
+            g_0 = TensorList(objective.get_grads())
             params.add_(cd)
             closure()
 
@@ -126,9 +126,9 @@ class SG2(Module):
 
 
     @torch.no_grad
-    def apply(self,
+    def apply(self, objective):
         dir = _newton_step(
-
+            objective=objective,
             H = self.global_state["H"],
             damping = self.defaults["damping"],
             inner = self.children.get("inner", None),
@@ -138,10 +138,10 @@ class SG2(Module):
             g_proj=None,
         )
 
-
-        return
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
 
-    def get_H(self,
+    def get_H(self,objective=...):
         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
 
 
@@ -198,12 +198,12 @@ class SPSA2(Module):
         if inner is not None: self.set_child('inner', inner)
 
     @torch.no_grad
-    def update(self,
+    def update(self, objective):
         k = self.global_state.get('step', 0) + 1
         self.global_state["step"] = k
 
-        params = TensorList(
-        closure =
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None:
             raise RuntimeError("closure is required for SPSA2")
 
@@ -260,7 +260,7 @@ class SPSA2(Module):
         H_hat /= n_samples
 
         # set grad to approximated grad
-
+        objective.grads = g_0
 
         # update H
         H = self.global_state.get("H", None)
@@ -273,9 +273,9 @@ class SPSA2(Module):
         self.global_state["H"] = H
 
     @torch.no_grad
-    def apply(self,
+    def apply(self, objective):
         dir = _newton_step(
-
+            objective=objective,
             H = self.global_state["H"],
             damping = self.defaults["damping"],
             inner = self.children.get("inner", None),
@@ -285,8 +285,8 @@ class SPSA2(Module):
             g_proj=None,
         )
 
-
-        return
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
 
-    def get_H(self,
+    def get_H(self,objective=...):
         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
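The SG2 and SPSA2 hunks above show the central API change in 0.4.0: module hooks now receive an Objective (carrying params, closure, grads and updates) instead of the 0.3.x Var object, and apply() returns the objective with its updates field set. Below is a minimal sketch of a custom module written against these signatures, using only attributes visible in this diff; the ScaleUpdate class and its alpha setting are hypothetical, so treat this as an illustration rather than code from the library itself.

    import torch
    from torchzero.core import Module

    class ScaleUpdate(Module):
        """Hypothetical module: scales the incoming update by a constant factor."""
        def __init__(self, alpha: float = 0.5):
            super().__init__(dict(alpha=alpha))

        @torch.no_grad
        def update(self, objective):
            # nothing to precompute; gradients would be available via objective.get_grads()
            pass

        @torch.no_grad
        def apply(self, objective):
            alpha = self.defaults["alpha"]
            objective.updates = [u * alpha for u in objective.get_updates()]
            return objective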
torchzero/modules/restarts/restars.py

@@ -4,12 +4,14 @@ from typing import final, Literal, cast
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Objective
 from ...utils import TensorList
 from ..termination import TerminationCriteriaBase
 
 def _reset_except_self(optimizer, var, self: Module):
-    for m in optimizer.unrolled_modules:
+    for m in optimizer.unrolled_modules:
+        if m is not self:
+            m.reset()
 
 class RestartStrategyBase(Module, ABC):
     """Base class for restart strategies.
@@ -24,7 +26,7 @@ class RestartStrategyBase(Module, ABC):
         self.set_child('modules', modules)
 
     @abstractmethod
-    def should_reset(self, var:
+    def should_reset(self, var: Objective) -> bool:
         """returns whether reset should occur"""
 
     def _reset_on_condition(self, var):
@@ -39,23 +41,23 @@ class RestartStrategyBase(Module, ABC):
         return modules
 
     @final
-    def update(self,
-        modules = self._reset_on_condition(
+    def update(self, objective):
+        modules = self._reset_on_condition(objective)
         if modules is not None:
-            modules.update(
+            modules.update(objective)
 
     @final
-    def apply(self,
+    def apply(self, objective):
         # don't check here because it was check in `update`
         modules = self.children.get('modules', None)
-        if modules is None: return
-        return modules.apply(
+        if modules is None: return objective
+        return modules.apply(objective.clone(clone_updates=False))
 
     @final
-    def step(self,
-        modules = self._reset_on_condition(
-        if modules is None: return
-        return modules.step(
+    def step(self, objective):
+        modules = self._reset_on_condition(objective)
+        if modules is None: return objective
+        return modules.step(objective.clone(clone_updates=False))
 
 
 
@@ -170,7 +172,7 @@ class PowellRestart(RestartStrategyBase):
         super().__init__(defaults, modules)
 
     def should_reset(self, var):
-        g = TensorList(var.
+        g = TensorList(var.get_grads())
         cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
 
         # -------------------------------- initialize -------------------------------- #
@@ -192,7 +194,7 @@ class PowellRestart(RestartStrategyBase):
 
         # ------------------------------- 2nd condition ------------------------------ #
         if (cond2 is not None) and (not reset):
-            d_g = TensorList(var.
+            d_g = TensorList(var.get_updates()).dot(g)
             if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                 reset = True
 
@@ -229,17 +231,17 @@ class BirginMartinezRestart(Module):
 
         self.set_child("module", module)
 
-    def update(self,
+    def update(self, objective):
         module = self.children['module']
-        module.update(
+        module.update(objective)
 
-    def apply(self,
+    def apply(self, objective):
         module = self.children['module']
-
+        objective = module.apply(objective.clone(clone_updates=False))
 
         cond = self.defaults['cond']
-        g = TensorList(
-        d = TensorList(
+        g = TensorList(objective.get_grads())
+        d = TensorList(objective.get_updates())
         d_g = d.dot(g)
         d_norm = d.global_vector_norm()
         g_norm = g.global_vector_norm()
@@ -247,7 +249,7 @@ class BirginMartinezRestart(Module):
         # d in our case is same direction as g so it has a minus sign
         if -d_g > -cond * d_norm * g_norm:
             module.reset()
-
-            return
+            objective.updates = g.clone()
+            return objective
 
-        return
+        return objective
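The restart hunks follow the same pattern: should_reset now takes an Objective, and the wrapped child modules are stepped on objective.clone(clone_updates=False) before the result is returned. A hedged sketch of a custom restart strategy built on RestartStrategyBase follows, assuming the base class is importable from torchzero.modules.restarts.restars as listed above; the RestartEveryN class and its n setting are made up for illustration.

    from torchzero.modules.restarts.restars import RestartStrategyBase

    class RestartEveryN(RestartStrategyBase):
        """Hypothetical strategy: resets the wrapped modules every n steps."""
        def __init__(self, modules, n: int = 100):
            super().__init__(dict(n=n), modules)

        def should_reset(self, var) -> bool:
            k = self.global_state.get("step", 0) + 1
            self.global_state["step"] = k
            return k % self.defaults["n"] == 0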
torchzero/modules/second_order/__init__.py

@@ -1,7 +1,7 @@
 from .ifn import InverseFreeNewton
-from .inm import
+from .inm import ImprovedNewton
 from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
 from .newton import Newton
 from .newton_cg import NewtonCG, NewtonCGSteihaug
 from .nystrom import NystromPCG, NystromSketchAndSolve
-from .rsn import
+from .rsn import SubspaceNewton
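For downstream code, the renames visible here are ImprovedNewton (previously INM, per the inm.py hunks below) and SubspaceNewton (previously exported from rsn.py under a different name). Assuming package-level re-exports mirror this __init__, a 0.4.0 import would look like the following sketch:

    from torchzero.modules.second_order import ImprovedNewton, SubspaceNewton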
torchzero/modules/second_order/ifn.py

@@ -1,89 +1,58 @@
-import warnings
-from collections.abc import Callable
-from functools import partial
-from typing import Literal
-
 import torch
 
-from ...core import Chainable,
+from ...core import Chainable, Transform, HessianMethod
 from ...utils import TensorList, vec_to_tensors
-from ...
-from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+from ...linalg.linear_operator import DenseWithInverse
 
 
-class InverseFreeNewton(
+class InverseFreeNewton(Transform):
    """Inverse-free newton's method
 
-    .. note::
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating the hessian.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. warning::
-        this uses roughly O(N^2) memory.
-
    Reference
        [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
    """
    def __init__(
        self,
        update_freq: int = 1,
-        hessian_method:
-
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
-        defaults = dict(hessian_method=hessian_method,
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = dict(hessian_method=hessian_method, h=h)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
    @torch.no_grad
-    def
-
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
-
-
+        _, _, H = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
 
-
-        loss, g_list, H = _get_loss_grad_and_hessian(
-            var, self.defaults['hessian_method'], self.defaults['vectorize']
-        )
-        self.global_state["H"] = H
+        self.global_state["H"] = H
 
-
-
-
-
+        # inverse free part
+        if 'Y' not in self.global_state:
+            num = H.T
+            denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
 
-
-
+            finfo = torch.finfo(H.dtype)
+            self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
 
-
-
-
-
-
+        else:
+            Y = self.global_state['Y']
+            I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+            I2 -= H @ Y
+            self.global_state['Y'] = Y @ I2
 
 
-    def
+    def apply_states(self, objective, states, settings):
        Y = self.global_state["Y"]
-
-
-
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
-
-        # ----------------------------------- solve ---------------------------------- #
-        var.update = vec_to_tensors(Y@g, params)
-
-        return var
+        g = torch.cat([t.ravel() for t in objective.get_updates()])
+        objective.updates = vec_to_tensors(Y@g, objective.params)
+        return objective
 
-    def get_H(self,
+    def get_H(self,objective=...):
        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
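The rewritten update_states above maintains an explicit inverse approximation Y of the Hessian H without ever calling a linear solver: Y is initialized as H^T / (||H||_1 * ||H||_inf) and then refined with the recurrence Y <- Y(2I - HY). Below is a standalone sketch of that Newton-Schulz-style recurrence in plain torch, independent of torchzero; the function name and step count are arbitrary choices for the example.

    import torch

    def inverse_free_iteration(H: torch.Tensor, steps: int = 10) -> torch.Tensor:
        # standard initialization that drives the iteration toward H^{-1}
        Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float("inf")))
        I = torch.eye(H.size(0), dtype=H.dtype, device=H.device)
        for _ in range(steps):
            Y = Y @ (2 * I - H @ Y)  # quadratically convergent inverse update
        return Y

    H = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
    print(torch.allclose(inverse_free_iteration(H) @ H, torch.eye(2), atol=1e-5))  # True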
torchzero/modules/second_order/inm.py

@@ -1,12 +1,11 @@
 from collections.abc import Callable
-from typing import Literal
 
 import torch
 
-from ...core import Chainable,
-from ...utils import TensorList, vec_to_tensors
+from ...core import Chainable, Transform, HessianMethod
+from ...utils import TensorList, vec_to_tensors, unpack_states
 from ..functional import safe_clip
-from .newton import _get_H,
+from .newton import _get_H, _newton_step
 
 @torch.no_grad
 def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
@@ -25,7 +24,7 @@ def _eigval_fn(J: torch.Tensor, fn) -> torch.Tensor:
     L, Q = torch.linalg.eigh(J) # pylint:disable=not-callable
     return (Q * L.unsqueeze(-2)) @ Q.mH
 
-class
+class ImprovedNewton(Transform):
     """Improved Newton's Method (INM).
 
     Reference:
@@ -37,69 +36,66 @@ class INM(Module):
         damping: float = 0,
         use_lstsq: bool = False,
         update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
+        inner: Chainable | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child("inner", inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner, )
 
     @torch.no_grad
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if step == 0:
-            x_prev.copy_(x_list)
-            f_prev.copy_(f_list)
-            self.global_state["P"] = J
-            return
-
-        # INM update
-        s_list = x_list - x_prev
-        y_list = f_list - f_prev
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+
+        _, f_list, J = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
+        if f_list is None: f_list = objective.get_grads()
+
+        f = torch.cat([t.ravel() for t in f_list])
+        J = _eigval_fn(J, fs["eigval_fn"])
+
+        x_list = TensorList(objective.params)
+        f_list = TensorList(objective.get_grads())
+        x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)
+
+        # initialize on 1st step, do Newton step
+        if "P" not in self.global_state:
            x_prev.copy_(x_list)
            f_prev.copy_(f_list)
+            self.global_state["P"] = J
+            return
 
-
+        # INM update
+        s_list = x_list - x_prev
+        y_list = f_list - f_prev
+        x_prev.copy_(x_list)
+        f_prev.copy_(f_list)
+
+        self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
 
 
     @torch.no_grad
-    def
-
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+
         update = _newton_step(
-
+            objective = objective,
             H = self.global_state["P"],
-            damping=
-
-
-
-            use_lstsq=self.defaults["use_lstsq"],
+            damping = fs["damping"],
+            H_tfm = fs["H_tfm"],
+            eigval_fn = None, # it is applied in `update`
+            use_lstsq = fs["use_lstsq"],
         )
 
-
+        objective.updates = vec_to_tensors(update, objective.params)
 
-        return
+        return objective
 
-    def get_H(self,
+    def get_H(self,objective=...):
        return _get_H(self.global_state["P"], eigval_fn=None)