torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/higher_order/higher_order_newton.py

@@ -13,7 +13,7 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
 from ...utils.derivatives import (
-
+    flatten_jacobian,
     jacobian_wrt,
 )
 
@@ -70,57 +70,94 @@ def _proximal_poly_H(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
 def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
     derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
     x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
-    bounds = None
-    if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
 
-    #
-
-
-
+    # notes
+    # 1. since we have exact hessian we use trust methods
+
+    # 2. if len(derivatives) is 1, only gradient is available,
+    # thus use slsqp depending on whether trust region is enabled
+    # this is just so that I can test that trust region works
+    if trust_region is None:
+        if len(derivatives) == 1: raise RuntimeError("trust region must be enabled because 1st order has no minima")
+        method = 'trust-exact'
+        de_bounds = list(zip(x0 - 10, x0 + 10))
+        constraints = None
+
     else:
-        if len(derivatives) == 1: method = '
+        if len(derivatives) == 1: method = 'slsqp'
         else: method = 'trust-constr'
+        de_bounds = list(zip(x0 - trust_region, x0 + trust_region))
+
+        def l2_bound_f(x):
+            if x.ndim == 2: return np.sum((x - x0[:,None])**2, axis=0)[None,:] # DE passes (ndim, batch_size) and expects (M, S)
+            return np.sum((x - x0)**2, axis=0)
+
+        def l2_bound_g(x):
+            return 2 * (x - x0)
+
+        def l2_bound_h(x, v):
+            return v[0] * 2 * np.eye(x0.shape[0])
+
+        constraint = scipy.optimize.NonlinearConstraint(
+            fun=l2_bound_f,
+            lb=0, # 0 <= ||x-x0||^2
+            ub=trust_region**2, # ||x-x0||^2 <= R^2
+            jac=l2_bound_g, # pyright:ignore[reportArgumentType]
+            hess=l2_bound_h,
+            keep_feasible=False
+        )
+        constraints = [constraint]
 
     x_init = x0.copy()
     v0 = _proximal_poly_v(x0, c, prox, x0, derivatives)
+
+    # ---------------------------------- run DE ---------------------------------- #
     if de_iters is not None and de_iters != 0:
         if de_iters == -1: de_iters = None # let scipy decide
+
+        # DE needs bounds so use linf ig
         res = scipy.optimize.differential_evolution(
             _proximal_poly_v,
-
+            de_bounds,
             args=(c, prox, x0.copy(), derivatives),
             maxiter=de_iters,
             vectorized=True,
+            constraints = constraints,
+            updating='deferred',
         )
-        if res.fun < v0: x_init = res.x
-
-    res = scipy.optimize.minimize(
-        _proximal_poly_v,
-        x_init,
-        method=method,
-        args=(c, prox, x0.copy(), derivatives),
-        jac=_proximal_poly_g,
-        hess=_proximal_poly_H,
-        bounds=bounds
-    )
+        if res.fun < v0 and np.all(np.isfinite(res.x)): x_init = res.x
 
+    # ------------------------------- run minimize ------------------------------- #
+    try:
+        res = scipy.optimize.minimize(
+            _proximal_poly_v,
+            x_init,
+            method=method,
+            args=(c, prox, x0.copy(), derivatives),
+            jac=_proximal_poly_g,
+            hess=_proximal_poly_H,
+            constraints = constraints,
+        )
+    except ValueError:
+        return x, -float('inf')
     return torch.from_numpy(res.x).to(x), res.fun
 
 
 
 class HigherOrderNewton(Module):
-    """
-    A basic arbitrary order newton's method with optional trust region and proximal penalty.
-    It is recommended to enable at least one of trust region or proximal penalty.
+    """A basic arbitrary order newton's method with optional trust region and proximal penalty.
 
     This constructs an nth order taylor approximation via autograd and minimizes it with
-    scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
+    ``scipy.optimize.minimize`` trust region newton solvers with optional proximal penalty.
 
-
-
-
-
-
+    The hessian of taylor approximation is easier to evaluate, plus it can be evaluated in a batched mode,
+    so it can be more efficient in very specific instances.
+
+    Notes:
+        - In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
+        - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating higher order derivatives. The closure must accept a ``backward`` argument (refer to documentation).
+        - this uses roughly O(N^order) memory and solving the subproblem is very expensive.
+        - "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
 
     Args:
 
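For orientation, the rewritten `_poly_minimize` above solves a trust-region model subproblem: it minimizes the Taylor model built from `derivatives` subject to an l2-ball constraint (the `NonlinearConstraint` with `ub=trust_region**2`), or with a proximal penalty when no bound is given. Roughly, and assuming the penalty inside `_proximal_poly_v` (not shown in this hunk) is the usual quadratic one:

$$
\min_{x}\;\; c \;+\; \sum_{k=1}^{K} \frac{1}{k!}\, D^{k} f(x_0)\,[x - x_0]^{k} \;+\; \frac{\mathrm{prox}}{2}\,\lVert x - x_0\rVert_2^{2}
\qquad \text{s.t.}\qquad \lVert x - x_0\rVert_2^{2} \;\le\; R^{2},
$$

with the constraint dropped in the proximal variant and the penalty dropped ($\mathrm{prox}=0$) in the bounded one.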
@@ -136,7 +173,7 @@ class HigherOrderNewton(Module):
         increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
         decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
         trust_init (float | None, optional):
-            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on
+            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on ``"proximal"``. Defaults to None.
         trust_tol (float, optional):
             Maximum ratio of expected loss reduction to actual reduction for trust region increase.
             Should 1 or higer. Defaults to 2.
@@ -149,38 +186,43 @@
         self,
         order: int = 4,
         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-
-
-
-
+        nplus: float = 3.5,
+        nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
+        init: float | None = None,
+        eta: float = 1e-6,
+        max_attempts = 10,
+        boundary_tol: float = 1e-2,
         de_iters: int | None = None,
         vectorize: bool = True,
     ):
-        if
-            if trust_method == 'bounds':
-            else:
+        if init is None:
+            if trust_method == 'bounds': init = 1
+            else: init = 0.1
 
-        defaults = dict(order=order, trust_method=trust_method,
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
         super().__init__(defaults)
 
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
         closure = var.closure
-        if closure is None: raise RuntimeError('
+        if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
 
         settings = self.settings[params[0]]
         order = settings['order']
-
-
-
-
+        nplus = settings['nplus']
+        nminus = settings['nminus']
+        eta = settings['eta']
+        init = settings['init']
         trust_method = settings['trust_method']
         de_iters = settings['de_iters']
+        max_attempts = settings['max_attempts']
         vectorize = settings['vectorize']
-
-
-
+        boundary_tol = settings['boundary_tol']
+        rho_good = settings['rho_good']
+        rho_bad = settings['rho_bad']
 
         # ------------------------ calculate grad and hessian ------------------------ #
         with torch.enable_grad():
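The hyperparameters introduced above (`nplus`, `nminus`, `rho_good`, `rho_bad`, `eta`, `boundary_tol`, `max_attempts`) feed a standard trust-region acceptance test that the next hunk implements. In rough form, with $f$ the true loss and $\hat f$ the Taylor model:

$$
\rho \;=\; \frac{f(x_0) - f(x^{\star})}{\max\!\bigl(\hat f(x_0) - \hat f(x^{\star}),\,10^{-8}\bigr)},
\qquad
R \;\leftarrow\;
\begin{cases}
R \cdot n_{-} & \text{if } \rho < \rho_{\text{bad}},\\
R \cdot n_{+} & \text{if } \rho > \rho_{\text{good}} \text{ and the step lands within } \texttt{boundary\_tol} \text{ of the radius},
\end{cases}
$$

and the step is accepted once $\rho > \eta$, with at most `max_attempts` retries per call.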
@@ -200,57 +242,86 @@
                 T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
                 with torch.no_grad() if is_last else nullcontext():
                     # the shape is (ndim, ) * order
-                    T =
+                    T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
                     derivatives.append(T)
 
         x0 = torch.cat([p.ravel() for p in params])
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        success = False
+        x_star = None
+        while not success:
+            max_attempts -= 1
+            if max_attempts < 0: break
+
+            # load trust region value
+            trust_value = self.global_state.get('trust_region', init)
+
+            # make sure its not too small or too large
+            finfo = torch.finfo(x0.dtype)
+            if trust_value < finfo.tiny*2 or trust_value > finfo.max / (2*nplus):
+                trust_value = self.global_state['trust_region'] = settings['init']
+
+            # determine tr and prox values
+            if trust_method is None: trust_method = 'none'
+            else: trust_method = trust_method.lower()
+
+            if trust_method == 'none':
+                trust_region = None
+                prox = 0
+
+            elif trust_method == 'bounds':
+                trust_region = trust_value
+                prox = 0
+
+            elif trust_method == 'proximal':
+                trust_region = None
+                prox = 1 / trust_value
+
+            else:
+                raise ValueError(trust_method)
+
+            # minimize the model
+            x_star, expected_loss = _poly_minimize(
+                trust_region=trust_region,
+                prox=prox,
+                de_iters=de_iters,
+                c=loss.item(),
+                x=x0,
+                derivatives=derivatives,
+            )
+
+            # update trust region
+            if trust_method == 'none':
+                success = True
+            else:
+                pred_reduction = loss - expected_loss
+
+                vec_to_tensors_(x_star, params)
+                loss_star = closure(False)
+                vec_to_tensors_(x0, params)
+                reduction = loss - loss_star
+
+                rho = reduction / (max(pred_reduction, 1e-8))
+                # failed step
+                if rho < rho_bad:
+                    self.global_state['trust_region'] = trust_value * nminus
+
+                # very good step
+                elif rho > rho_good:
+                    step = (x_star - x0)
+                    magn = torch.linalg.vector_norm(step) # pylint:disable=not-callable
+                    if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
+                        # close to boundary
+                        self.global_state['trust_region'] = trust_value * nplus
+
+                # if the ratio is high enough then accept the proposed step
+                success = rho > eta
+
+        assert x_star is not None
+        if success:
+            difference = vec_to_tensors(x0 - x_star, params)
+            var.update = list(difference)
         else:
-
-
-            x_star, expected_loss = _poly_minimize(
-                trust_region=trust_region,
-                prox=prox,
-                de_iters=de_iters,
-                c=loss.item(),
-                x=x0,
-                derivatives=derivatives,
-            )
-
-            # trust region
-            if trust_method != 'none':
-                expected_reduction = loss - expected_loss
-
-                vec_to_tensors_(x_star, params)
-                loss_star = closure(False)
-                vec_to_tensors_(x0, params)
-                reduction = loss - loss_star
-
-                # failed step
-                if reduction <= 0:
-                    x_star = x0
-                    self.global_state['trust_value'] = trust_value * decrease
-
-                # very good step
-                elif expected_reduction / reduction <= trust_tol:
-                    self.global_state['trust_value'] = trust_value * increase
-
-        difference = vec_to_tensors(x0 - x_star, params)
-        var.update = list(difference)
+            var.update = params.zeros_like()
         return var
 
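A minimal usage sketch for the reworked module, mirroring the closure pattern from the `GaussNewton` docstring examples added later in this diff; the `tz.m.HigherOrderNewton` export name and the constructor arguments shown are assumptions based on the hunks above, not something this diff states directly:

```python
import torch
import torchzero as tz

def rosenbrock(X):
    x1, x2 = X
    return (1 - x1)**2 + 100 * (x2 - x1**2)**2

X = torch.tensor([-1.1, 2.5], requires_grad=True)

# hypothetical export name; the class itself lives in
# torchzero/modules/higher_order/higher_order_newton.py (see the hunks above)
opt = tz.Modular([X], tz.m.HigherOrderNewton(order=4, trust_method='bounds'))

# the closure must accept a `backward` argument (see the Notes in the new docstring);
# the module re-evaluates the loss and higher-order derivatives itself via autograd
def closure(backward=True):
    return rosenbrock(X)

for _ in range(20):
    loss = opt.step(closure)
```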
torchzero/modules/least_squares/__init__.py

@@ -0,0 +1 @@
+from .gn import SumOfSquares, GaussNewton
torchzero/modules/least_squares/gn.py

@@ -0,0 +1,161 @@
+import torch
+from ...core import Module
+
+from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+from ...utils import vec_to_tensors
+from ...utils.linalg import linear_operator
+class SumOfSquares(Module):
+    """Sets loss to be the sum of squares of values returned by the closure.
+
+    This is meant to be used to test least squares methods against ordinary minimization methods.
+
+    To use this, the closure should return a vector of values to minimize sum of squares of.
+    Please add the `backward` argument, it will always be False but it is required.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+
+        if closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False)
+                        loss = loss.pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False)
+                return loss.pow(2).sum()
+
+            var.closure = sos_closure
+
+        if var.loss is not None:
+            var.loss = var.loss.pow(2).sum()
+
+        if var.loss_approx is not None:
+            var.loss_approx = var.loss_approx.pow(2).sum()
+
+        return var
+
+
+class GaussNewton(Module):
+    """Gauss-newton method.
+
+    To use this, the closure should return a vector of values to minimize sum of squares of.
+    Please add the ``backward`` argument, it will always be False but it is required.
+    Gradients will be calculated via batched autograd within this module, you don't need to
+    implement the backward pass. Please see below for an example.
+
+    Note:
+        This method requires ``ndim^2`` memory, however, if it is used within ``tz.m.TrustCG`` trust region,
+        the memory requirement is ``ndim*m``, where ``m`` is number of values in the output.
+
+    Args:
+        reg (float, optional): regularization parameter. Defaults to 1e-8.
+        batched (bool, optional): whether to use vmapping. Defaults to True.
+
+    Examples:
+
+    minimizing the rosenbrock function:
+    ```python
+    def rosenbrock(X):
+        x1, x2 = X
+        return torch.stack([(1 - x1), 100 * (x2 - x1**2)])
+
+    X = torch.tensor([-1.1, 2.5], requires_grad=True)
+    opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())
+
+    # define the closure for line search
+    def closure(backward=True):
+        return rosenbrock(X)
+
+    # minimize
+    for iter in range(10):
+        loss = opt.step(closure)
+        print(f'{loss = }')
+    ```
+
+    training a neural network with a matrix-free GN trust region:
+    ```python
+    X = torch.randn(64, 20)
+    y = torch.randn(64, 10)
+
+    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.GaussNewton()),
+    )
+
+    def closure(backward=True):
+        y_hat = model(X) # (64, 10)
+        return (y_hat - y).pow(2).mean(0) # (10, )
+
+    for i in range(100):
+        losses = opt.step(closure)
+        if i % 10 == 0:
+            print(f'{losses.mean() = }')
+    ```
+    """
+    def __init__(self, reg:float = 1e-8, batched:bool=True, ):
+        super().__init__(defaults=dict(batched=batched, reg=reg))
+
+    @torch.no_grad
+    def update(self, var):
+        params = var.params
+        batched = self.defaults['batched']
+
+        closure = var.closure
+        assert closure is not None
+
+        # gauss newton direction
+        with torch.enable_grad():
+            f = var.get_loss(backward=False) # n_out
+            assert isinstance(f, torch.Tensor)
+            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+
+        var.loss = f.pow(2).sum()
+
+        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_out, ndim)
+        Gtf = G.T @ f.detach() # (ndim)
+        self.global_state["Gtf"] = Gtf
+        var.grad = vec_to_tensors(Gtf, var.params)
+
+        # set closure to calculate sum of squares for line searches etc
+        if var.closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False).pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False).pow(2).sum()
+                return loss
+
+            var.closure = sos_closure
+
+    @torch.no_grad
+    def apply(self, var):
+        reg = self.defaults['reg']
+
+        G = self.global_state['G']
+        Gtf = self.global_state['Gtf']
+
+        GtG = G.T @ G # (ndim, ndim)
+        if reg != 0:
+            GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))
+
+        v = torch.linalg.lstsq(GtG, Gtf).solution # pylint:disable=not-callable
+
+        var.update = vec_to_tensors(v, var.params)
+        return var
+
+    def get_H(self, var):
+        G = self.global_state['G']
+        return linear_operator.AtA(G)
torchzero/modules/line_search/__init__.py

@@ -1,5 +1,5 @@
-from .
-from .backtracking import
-from .
+from .adaptive import AdaptiveTracking
+from .backtracking import AdaptiveBacktracking, Backtracking
+from .line_search import LineSearchBase
 from .scipy import ScipyMinimizeScalar
-from .
+from .strong_wolfe import StrongWolfe