torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/newton_cg.py

````diff
@@ -1,16 +1,16 @@
-
-import math
-from typing import Literal, cast
+
 from operator import itemgetter
+from typing import Literal, cast
+
 import torch
 
-from ...core import Chainable,
-from ...utils import TensorList,
-from ...
-from ...utils.linalg.solve import cg, minres, find_within_trust_radius
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import TensorList, tofloat, unpack_dicts, unpack_states
+from ...linalg.solve import cg, find_within_trust_radius, minres
 from ..trust_region.trust_region import default_radius
 
-class NewtonCG(Module):
+
+class NewtonCG(Transform):
     """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
 
     Notes:
@@ -37,17 +37,14 @@ class NewtonCG(Module):
         hvp_method (str, optional):
             Determines how Hessian-vector products are evaluated.
 
-            - ``"autograd"
-
-            - ``"
-
-
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            For NewtonCG ``"batched_autograd"`` is equivalent to ``"autograd"``. Defaults to ``"autograd"``.
         h (float, optional):
-            The step size for finite
-            ``"
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         warm_start (bool, optional):
             If ``True``, the conjugate gradient solver is initialized with the
             solution from the previous optimization step. This can accelerate
````
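The ``"fd_forward"``/``"fd_central"`` options documented above are the standard gradient finite differences. Here is a minimal self-contained sketch of both formulas on a flat tensor; `hvp_fd` and `grad_fn` are illustrative names, not torchzero's actual helpers:

```python
import torch

def hvp_fd(grad_fn, x, v, h=1e-3, central=False):
    # grad_fn(x) -> gradient of the loss at x (one backward pass per call).
    # Normalizing v keeps the effective step `h` independent of the scale of v.
    vnorm = v.norm().clamp(min=torch.finfo(v.dtype).tiny)
    u = v / vnorm
    if central:
        # "fd_central": two gradient evaluations per product, O(h^2) error
        hu = (grad_fn(x + h * u) - grad_fn(x - h * u)) / (2 * h)
    else:
        # "fd_forward": one *extra* gradient evaluation, O(h) error
        # (grad_fn(x) is typically the gradient already computed at x0)
        hu = (grad_fn(x + h * u) - grad_fn(x)) / h
    return hu * vnorm  # H is linear: Hv = ||v|| * H(v / ||v||)

# sanity check on a quadratic f(x) = 0.5 x^T A x, where Hv = Av exactly
A = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
grad_fn = lambda x: A @ x
v = torch.tensor([1.0, -2.0])
print(hvp_fd(grad_fn, torch.zeros(2), v, central=True))  # ~ A @ v
```

The forward formula reuses the gradient already available at the current point, which is why the docstring counts one extra gradient evaluation for it versus two for the central formula.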
````diff
@@ -60,7 +57,7 @@ class NewtonCG(Module):
     Newton-CG with a backtracking line search:
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.NewtonCG(),
         tz.m.Backtracking()
@@ -69,7 +66,7 @@ class NewtonCG(Module):
 
     Truncated Newton method (useful for large-scale problems):
     ```
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.NewtonCG(maxiter=10),
         tz.m.Backtracking()
@@ -82,100 +79,72 @@ class NewtonCG(Module):
         maxiter: int | None = None,
         tol: float = 1e-8,
         reg: float = 1e-8,
-        hvp_method:
-        solver: Literal['cg', 'minres'
+        hvp_method: HVPMethod = "autograd",
+        solver: Literal['cg', 'minres'] = 'cg',
+        npc_terminate: bool = False,
         h: float = 1e-3, # tuned 1e-4 or 1e-3
         miniter:int = 1,
         warm_start=False,
+        warm_beta:float=0,
         inner: Chainable | None = None,
     ):
         defaults = locals().copy()
         del defaults['self'], defaults['inner']
-        super().__init__(defaults,)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        super().__init__(defaults, inner=inner)
 
         self._num_hvps = 0
         self._num_hvps_last_step = 0
 
     @torch.no_grad
-    def
-
-
-
-
-        settings = self.settings[params[0]]
-        tol = settings['tol']
-        reg = settings['reg']
-        maxiter = settings['maxiter']
-        hvp_method = settings['hvp_method']
-        solver = settings['solver'].lower().strip()
-        h = settings['h']
-        warm_start = settings['warm_start']
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
-        self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
-
-
-
-            def H_mm(x):
-                self._num_hvps_last_step += 1
-                with torch.enable_grad():
-                    return TensorList(hvp(params, grad, x, retain_graph=True))
-
-        else:
-
-            with torch.enable_grad():
-                grad = var.get_grad()
-
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
-
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
-
-            else:
-                raise ValueError(hvp_method)
+        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+        objective.temp = H_mv
 
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        self._num_hvps_last_step = 0
+        H_mv = objective.poptemp()
 
-
-
-
-
-
+        fs = settings[0]
+        tol = fs['tol']
+        reg = fs['reg']
+        maxiter = fs['maxiter']
+        solver = fs['solver'].lower().strip()
+        warm_start = fs['warm_start']
+        npc_terminate = fs["npc_terminate"]
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
-        if warm_start:
+        if warm_start:
+            x0 = unpack_states(states, objective.params, 'prev_x', cls=TensorList)
+
+        b = TensorList(objective.get_updates())
 
         if solver == 'cg':
-            d, _ = cg(
+            d, _ = cg(A_mv=H_mv, b=b, x0=x0, tol=tol, maxiter=maxiter,
+                      miniter=fs["miniter"], reg=reg, npc_terminate=npc_terminate)
 
         elif solver == 'minres':
-            d = minres(
-
-        elif solver == 'minres_npc':
-            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+            d = minres(A_mv=H_mv, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
 
         else:
             raise ValueError(f"Unknown solver {solver}")
 
         if warm_start:
             assert x0 is not None
-            x0.
-
-        var.update = d
+            x0.lerp_(d, weight = 1-fs["warm_beta"])
 
+        objective.updates = d
         self._num_hvps += self._num_hvps_last_step
-        return
+        return objective
 
 
-class NewtonCGSteihaug(
+class NewtonCGSteihaug(Transform):
     """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.
 
     Notes:
````
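Both classes now reach the Hessian only through the `H_mv` callable returned by `objective.list_Hvp_function`, so the Newton system `H d = g` is solved matrix-free. For reference, a textbook conjugate-gradient solver with the same `A_mv` calling convention; this is a simplified flat-tensor sketch with an illustrative name, while torchzero's `cg` operates on TensorLists and adds `miniter` and `npc_terminate` handling:

```python
import torch

def cg_matrix_free(A_mv, b, x0=None, tol=1e-8, maxiter=None, reg=0.0):
    # Solve (A + reg*I) x = b where A is available only through the
    # matvec callable A_mv, as in the cg(A_mv=H_mv, ...) call above.
    x = torch.zeros_like(b) if x0 is None else x0.clone()
    r = b - (A_mv(x) + reg * x)   # initial residual
    p = r.clone()                 # initial search direction
    rs = r.dot(r)
    for _ in range(maxiter if maxiter is not None else b.numel()):
        Ap = A_mv(p) + reg * p
        alpha = rs / Ap.dot(p)    # exact line search along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r.dot(r)
        if rs_new.sqrt() <= tol:  # residual small enough: converged
            break
        p = r + (rs_new / rs) * p # conjugate direction update
        rs = rs_new
    return x
```

With `warm_start=True` the previous solution is passed as `x0`, and the new `warm_beta` option blends successive solutions via `x0.lerp_(d, weight=1-warm_beta)` rather than overwriting the stored guess outright.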
````diff
@@ -219,7 +188,7 @@ class NewtonCGSteihaug(Module):
             whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.
 
         hvp_method (str, optional):
-            either "
+            either ``"fd_forward"`` to use forward formula which requires one backward pass per hessian-vector product, or ``"fd_central"`` to use a more accurate central formula which requires two backward passes. ``"fd_forward"`` is usually accurate enough. Defaults to ``"fd_forward"``.
         h (float, optional): finite difference step size. Defaults to 1e-3.
 
         inner (Chainable | None, optional):
@@ -229,7 +198,7 @@ class NewtonCGSteihaug(Module):
     Trust-region Newton-CG:
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.NewtonCGSteihaug(),
     )
@@ -261,7 +230,7 @@ class NewtonCGSteihaug(Module):
         npc_terminate: bool = False,
 
         # hvp settings
-        hvp_method: Literal["
+        hvp_method: Literal["fd_forward", "fd_central"] = "fd_central",
         h: float = 1e-3, # tuned 1e-4 or 1e-3
 
         # inner
@@ -269,72 +238,51 @@ class NewtonCGSteihaug(Module):
     ):
         defaults = locals().copy()
         del defaults['self'], defaults['inner']
-        super().__init__(defaults,)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        super().__init__(defaults, inner=inner)
 
         self._num_hvps = 0
         self._num_hvps_last_step = 0
 
-    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
-        solver = self.defaults['solver'].lower().strip()
-
-        (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
-         eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
-         miniter, max_history, adapt_tol) = itemgetter(
-            "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
-            "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
-            "miniter", "max_history", "adapt_tol",
-        )(self.defaults)
 
-
+    @torch.no_grad
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
         # ---------------------- Hessian vector product function --------------------- #
-
-
-
-        def H_mm(x):
-            self._num_hvps_last_step += 1
-            with torch.enable_grad():
-                return TensorList(hvp(params, grad, x, retain_graph=True))
-
-        else:
-
-        with torch.enable_grad():
-            grad = var.get_grad()
+        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+        objective.temp = H_mv
 
-
-
-
-        return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        self._num_hvps_last_step = 0
 
-
-
-
-        return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+        H_mv = objective.poptemp()
+        params = TensorList(objective.params)
+        fs = settings[0]
 
-
-
+        tol = fs['tol'] * self.global_state.get('tol_mul', 1)
+        solver = fs['solver'].lower().strip()
 
+        reg=fs["reg"]
+        maxiter=fs["maxiter"]
+        max_attempts=fs["max_attempts"]
+        init=fs["init"]
+        npc_terminate=fs["npc_terminate"]
+        miniter=fs["miniter"]
+        max_history=fs["max_history"]
+        adapt_tol=fs["adapt_tol"]
 
-        # -------------------------------- inner step -------------------------------- #
-        b = var.get_update()
-        if 'inner' in self.children:
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-        b = as_tensorlist(b)
 
         # ------------------------------- trust region ------------------------------- #
         success = False
         d = None
-
+        orig_params = [p.clone() for p in params]
+        b = TensorList(objective.get_updates())
         solution = None
+        closure = objective.closure
+        assert closure is not None
 
         while not success:
            max_attempts -= 1
@@ -343,7 +291,7 @@ class NewtonCGSteihaug(Module):
            trust_radius = self.global_state.get('trust_radius', init)
 
            # -------------- make sure trust radius isn't too small or large ------------- #
-            finfo = torch.finfo(
+            finfo = torch.finfo(orig_params[0].dtype)
            if trust_radius < finfo.tiny * 2:
                trust_radius = self.global_state['trust_radius'] = init
                if adapt_tol:
@@ -360,7 +308,7 @@ class NewtonCGSteihaug(Module):
            if d is None:
                if solver == 'cg':
                    d, solution = cg(
-
+                        A_mv=H_mv,
                        b=b,
                        tol=tol,
                        maxiter=maxiter,
@@ -372,40 +320,40 @@ class NewtonCGSteihaug(Module):
                    )
 
            elif solver == 'minres':
-                d = minres(
+                d = minres(A_mv=H_mv, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
 
            else:
                raise ValueError(f"unknown solver {solver}")
 
            # ---------------------------- update trust radius --------------------------- #
            self.global_state["trust_radius"], success = default_radius(
-                params=params,
-                closure=closure,
-                f=tofloat(
-                g=b,
-                H=
-                d=d,
-                trust_radius=trust_radius,
-                eta=eta,
-                nplus=nplus,
-                nminus=nminus,
-                rho_good=rho_good,
-                rho_bad=rho_bad,
-                boundary_tol=boundary_tol,
-
-                init=
-                state=
-                settings=
-                check_overflow=False, # this is checked manually to adapt tolerance
+                params = params,
+                closure = closure,
+                f = tofloat(objective.get_loss(False)),
+                g = b,
+                H = H_mv,
+                d = d,
+                trust_radius = trust_radius,
+                eta = fs["eta"],
+                nplus = fs["nplus"],
+                nminus = fs["nminus"],
+                rho_good = fs["rho_good"],
+                rho_bad = fs["rho_bad"],
+                boundary_tol = fs["boundary_tol"],
+
+                init = cast(int, None), # init isn't used because check_overflow=False
+                state = cast(dict, None), # not used
+                settings = cast(dict, None), # not used
+                check_overflow = False, # this is checked manually to adapt tolerance
            )
 
            # --------------------------- assign new direction --------------------------- #
            assert d is not None
            if success:
-
+                objective.updates = d
 
            else:
-
+                objective.updates = params.zeros_like()
 
        self._num_hvps += self._num_hvps_last_step
-        return
+        return objective
````
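The `eta`/`nplus`/`nminus`/`rho_good`/`rho_bad`/`boundary_tol` settings forwarded to `default_radius` are the knobs of a standard trust-region ratio test. A sketch of the typical update rule under that reading; the function name and constants are illustrative, and torchzero's actual `default_radius` takes the `params`/`closure`/`H` arguments shown in the diff and may differ in detail:

```python
def trust_radius_update(actual_reduction, predicted_reduction, step_norm,
                        trust_radius, eta=0.0, nplus=2.0, nminus=0.25,
                        rho_good=0.75, rho_bad=0.25, boundary_tol=0.1):
    # rho measures how well the quadratic model predicted the true loss change
    rho = actual_reduction / max(predicted_reduction, 1e-32)

    if rho < rho_bad:
        trust_radius *= nminus  # model fit was poor: shrink the region
    elif rho > rho_good and step_norm >= (1 - boundary_tol) * trust_radius:
        trust_radius *= nplus   # good fit and the step hit the boundary: grow

    accepted = rho > eta        # accept only if the step made real progress
    return trust_radius, accepted
```

When a step is rejected, the diff above sets `objective.updates = params.zeros_like()`, i.e. no parameter movement for that attempt, and the `while not success:` loop retries with the shrunken radius until it succeeds or `max_attempts` runs out.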