torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/optim/wrappers/scipy/experimental.py
ADDED
@@ -0,0 +1,141 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import numpy as np
+import scipy.optimize
+import torch
+
+from ....utils import TensorList
+from ..wrapper import WrapperBase
+
+Closure = Callable[[bool], Any]
+
+
+class ScipyRootOptimization(WrapperBase):
+
+    """Optimization via using scipy.optimize.root on gradients, mainly for experimenting!
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        method (str | None, optional): _description_. Defaults to None.
+        tol (float | None, optional): _description_. Defaults to None.
+        callback (_type_, optional): _description_. Defaults to None.
+        options (_type_, optional): _description_. Defaults to None.
+        jac (T.Literal['2, optional): _description_. Defaults to 'autograd'.
+    """
+    def __init__(
+        self,
+        params,
+        method: Literal[
+            "hybr",
+            "lm",
+            "broyden1",
+            "broyden2",
+            "anderson",
+            "linearmixing",
+            "diagbroyden",
+            "excitingmixing",
+            "krylov",
+            "df-sane",
+        ] = 'hybr',
+        tol: float | None = None,
+        callback = None,
+        options = None,
+        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+    ):
+        super().__init__(params, {})
+        self.method = method
+        self.tol = tol
+        self.callback = callback
+        self.options = options
+
+        self.jac = jac
+        if self.jac == 'autograd': self.jac = True
+
+        # those don't require jacobian
+        if self.method.lower() in ('broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden', 'excitingmixing', 'krylov', 'df-sane'):
+            self.jac = None
+
+    def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+        if self.jac:
+            f, g, H = self._f_g_H(x, params, closure)
+            return g, H
+
+        f, g = self._f_g(x, params, closure)
+        return g
+
+    @torch.no_grad
+    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
+
+        res = scipy.optimize.root(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            method=self.method,
+            tol=self.tol,
+            callback=self.callback,
+            options=self.options,
+            jac = self.jac,
+        )
+
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
+        return res.fun
+
+
+class ScipyLeastSquaresOptimization(WrapperBase):
+    """Optimization via using scipy.optimize.least_squares on gradients, mainly for experimenting!
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        method (str | None, optional): _description_. Defaults to None.
+        tol (float | None, optional): _description_. Defaults to None.
+        callback (_type_, optional): _description_. Defaults to None.
+        options (_type_, optional): _description_. Defaults to None.
+        jac (T.Literal['2, optional): _description_. Defaults to 'autograd'.
+    """
+    def __init__(
+        self,
+        params,
+        method='trf',
+        jac='autograd',
+        bounds=(-np.inf, np.inf),
+        ftol=1e-8, xtol=1e-8, gtol=1e-8, x_scale=1.0, loss='linear',
+        f_scale=1.0, diff_step=None, tr_solver=None, tr_options=None,
+        jac_sparsity=None, max_nfev=None, verbose=0
+    ):
+        super().__init__(params, {})
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__'], kwargs['jac']
+        self._kwargs = kwargs
+
+        self.jac = jac
+
+
+    def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+        f, g = self._f_g(x, params, closure)
+        return g
+
+    def _hess(self, x: np.ndarray, params: list[torch.Tensor], closure):
+        f,g,H = self._f_g_H(x, params, closure)
+        return H
+
+    @torch.no_grad
+    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
+
+        if self.jac == 'autograd': jac = partial(self._hess, params = params, closure = closure)
+        else: jac = self.jac
+
+        res = scipy.optimize.least_squares(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            jac=jac, # type:ignore
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
+        return res.fun
+
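A stationary point of the loss satisfies ∇f(x) = 0, so ScipyRootOptimization hands the flattened gradient to scipy.optimize.root as the residual and, when a Jacobian is requested, supplies the Hessian as the Jacobian of that residual (this is what `_objective` returns above). A minimal usage sketch under assumptions: the model, data and exact import path are illustrative only, and the closure follows the `Closure = Callable[[bool], Any]` convention from the module, returning the loss and populating gradients only when the flag is true.

import torch
import torch.nn.functional as F
from torchzero.optim.wrappers.scipy.experimental import ScipyRootOptimization  # assumed import path

model = torch.nn.Linear(4, 1)                   # hypothetical model
X, y = torch.randn(64, 4), torch.randn(64, 1)   # hypothetical data

opt = ScipyRootOptimization(model.parameters(), method='hybr')

def closure(backward=True):
    # evaluate the loss; populate .grad only when backward is requested
    loss = F.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # one call runs scipy.optimize.root on the gradient field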
torchzero/optim/wrappers/scipy/minimize.py
ADDED
@@ -0,0 +1,151 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import numpy as np
+import scipy.optimize
+import torch
+
+from ....utils import TensorList
+from ..wrapper import WrapperBase
+
+Closure = Callable[[bool], Any]
+
+
+def _use_jac_hess_hessp(method, jac, hess, use_hessp):
+    # those methods can't use hessp
+    if (method is None) or (method.lower() not in ("newton-cg", "trust-ncg", "trust-krylov", "trust-constr")):
+        use_hessp = False
+
+    # those use gradients
+    use_jac_autograd = (jac.lower() == 'autograd') and ((method is None) or (method.lower() in [
+        'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
+        'trust-ncg', 'trust-krylov', 'trust-exact', 'trust-constr',
+    ]))
+
+    # those use hessian/ some of them can use hessp instead
+    use_hess_autograd = (isinstance(hess, str)) and (hess.lower() == 'autograd') and (method is not None) and (method.lower() in [
+        'newton-cg', 'dogleg', 'trust-ncg', 'trust-krylov', 'trust-exact'
+    ])
+
+    # jac in scipy is '2-point', '3-point', 'cs', True or None.
+    if jac == 'autograd':
+        if use_jac_autograd: jac = True
+        else: jac = None
+
+    return jac, use_jac_autograd, use_hess_autograd, use_hessp
+
+class ScipyMinimize(WrapperBase):
+    """Use scipy.minimize.optimize as pytorch optimizer. Note that this performs full minimization on each step,
+    so usually you would want to perform a single step, although performing multiple steps will refine the
+    solution.
+
+    Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
+    for a detailed description of args.
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        method (str | None, optional): type of solver.
+            If None, scipy will select one of BFGS, L-BFGS-B, SLSQP,
+            depending on whether or not the problem has constraints or bounds.
+            Defaults to None.
+        bounds (optional): bounds on variables. Defaults to None.
+        constraints (tuple, optional): constraints definition. Defaults to ().
+        tol (float | None, optional): Tolerance for termination. Defaults to None.
+        callback (Callable | None, optional): A callable called after each iteration. Defaults to None.
+        options (dict | None, optional): A dictionary of solver options. Defaults to None.
+        jac (str, optional): Method for computing the gradient vector.
+            Only for CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
+            In addition to scipy options, this supports 'autograd', which uses pytorch autograd.
+            This setting is ignored for methods that don't require gradient. Defaults to 'autograd'.
+        hess (str, optional):
+            Method for computing the Hessian matrix.
+            Only for Newton-CG, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
+            This setting is ignored for methods that don't require hessian. Defaults to 'autograd'.
+        tikhonov (float, optional):
+            optional hessian regularizer value. Only has effect for methods that require hessian.
+    """
+    def __init__(
+        self,
+        params,
+        method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
+                        'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
+                        'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
+                        'trust-krylov'] | str | None = None,
+        lb = None,
+        ub = None,
+        constraints = (),
+        tol: float | None = None,
+        callback = None,
+        options = None,
+        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+        hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
+        use_hessp: bool = True,
+    ):
+        defaults = dict(lb=lb, ub=ub)
+        super().__init__(params, defaults)
+        self.method = method
+        self.constraints = constraints
+        self.tol = tol
+        self.callback = callback
+        self.options = options
+        self.hess = hess
+
+        self.jac, self.use_jac_autograd, self.use_hess_autograd, self.use_hessp = _use_jac_hess_hessp(method, jac, hess, use_hessp)
+
+    def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+        if self.use_jac_autograd:
+            f, g = self._f_g(x, params, closure)
+            if self.method is not None and self.method.lower() == 'slsqp': g = g.astype(np.float64) # slsqp requires float64
+            return f, g
+
+        return self._f(x, params, closure)
+
+    def _hess(self, x: np.ndarray, params: list[torch.Tensor], closure):
+        f,g,H = self._f_g_H(x, params, closure)
+        return H
+
+    def _hessp(self, x: np.ndarray, p:np.ndarray, params: list[torch.Tensor], closure):
+        f,g,Hvp = self._f_g_Hvp(x, p, params, closure)
+        return Hvp
+
+    @torch.no_grad
+    def step(self, closure: Closure):# pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
+        bounds = self._get_bounds()
+
+        # determine hess argument
+        hess = self.hess
+        hessp = None
+        if hess == 'autograd':
+            if self.use_hess_autograd:
+                if self.use_hessp:
+                    hessp = partial(self._hessp, params=params, closure=closure)
+                    hess = None
+                else:
+                    hess = partial(self._hess, params=params, closure=closure)
+            # hess = 'autograd' but method doesn't use hess
+            else:
+                hess = None
+
+
+        if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
+            x0 = x0.astype(np.float64) # those methods error without this
+
+        res = scipy.optimize.minimize(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            method=self.method,
+            bounds=bounds,
+            constraints=self.constraints,
+            tol=self.tol,
+            callback=self.callback,
+            options=self.options,
+            jac = self.jac,
+            hess = hess,
+            hessp = hessp
+        )
+
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
+        return res.fun
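Because `lb` and `ub` are stored as per-group defaults, bounds can differ between parameter groups. A hedged usage sketch (the parameter names, objective and import path are illustrative assumptions; as the docstring notes, a single `step` call runs one full scipy.optimize.minimize):

import torch
from torchzero.optim.wrappers.scipy.minimize import ScipyMinimize  # assumed import path

w = torch.nn.Parameter(torch.randn(8))
b = torch.nn.Parameter(torch.zeros(1))

# bound the weights to [-1, 1]; the bias group keeps the lb=None/ub=None defaults (unbounded)
opt = ScipyMinimize(
    [{"params": [w], "lb": -1.0, "ub": 1.0},
     {"params": [b]}],
    method='l-bfgs-b',
)

def closure(backward=True):
    loss = (w.sum() - 3.0) ** 2 + b.pow(2).sum()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

final_loss = opt.step(closure)  # full minimization; returns scipy's final objective value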
torchzero/optim/wrappers/scipy/sgho.py
ADDED
@@ -0,0 +1,111 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import numpy as np
+import scipy.optimize
+import torch
+
+from ....utils import TensorList
+from ..wrapper import WrapperBase
+from .minimize import _use_jac_hess_hessp
+
+Closure = Callable[[bool], Any]
+
+
+class ScipySHGO(WrapperBase):
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        constraints = None,
+        n: int = 100,
+        iters: int = 1,
+        callback = None,
+        options: dict | None = None,
+        sampling_method: str = 'simplicial',
+        minimizer_kwargs: dict | None = None,
+        method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
+                        'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
+                        'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
+                        'trust-krylov'] | str = 'l-bfgs-b',
+        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+        hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
+        use_hessp: bool = True,
+    ):
+        super().__init__(params, dict(lb=lb, ub=ub))
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__'], kwargs["options"]
+        del kwargs["method"], kwargs["jac"], kwargs["hess"], kwargs["use_hessp"], kwargs["minimizer_kwargs"]
+        self._kwargs = kwargs
+        self.minimizer_kwargs = minimizer_kwargs
+        self.options = options
+        self.method = method
+        self.hess = hess
+
+        self.jac, self.use_jac_autograd, self.use_hess_autograd, self.use_hessp = _use_jac_hess_hessp(method, jac, hess, use_hessp)
+
+
+    def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+        if self.use_jac_autograd:
+            f, g = self._f_g(x, params, closure)
+            if self.method.lower() == 'slsqp': g = g.astype(np.float64) # slsqp requires float64
+            return f, g
+
+        return self._f(x, params, closure)
+
+    def _hess(self, x: np.ndarray, params: list[torch.Tensor], closure):
+        f,g,H = self._f_g_H(x, params, closure)
+        return H
+
+    def _hessp(self, x: np.ndarray, p:np.ndarray, params: list[torch.Tensor], closure):
+        f,g,Hvp = self._f_g_Hvp(x, p, params, closure)
+        return Hvp
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = TensorList(self._get_params())
+        x0 = params.to_vec().numpy(force=True)
+        bounds = self._get_bounds()
+        assert bounds is not None
+
+        # determine hess argument
+        hess = self.hess
+        hessp = None
+        if hess == 'autograd':
+            if self.use_hess_autograd:
+                if self.use_hessp:
+                    hessp = partial(self._hessp, params=params, closure=closure)
+                    hess = None
+                else:
+                    hess = partial(self._hess, params=params, closure=closure)
+            # hess = 'autograd' but method doesn't use hess
+            else:
+                hess = None
+
+
+        if self.method.lower() in ('tnc', 'slsqp'):
+            x0 = x0.astype(np.float64) # those methods error without this
+
+        minimizer_kwargs = self.minimizer_kwargs.copy() if self.minimizer_kwargs is not None else {}
+        minimizer_kwargs.setdefault("method", self.method)
+
+        options = self.options.copy() if self.options is not None else {}
+        minimizer_kwargs.setdefault("jac", self.jac)
+        minimizer_kwargs.setdefault("hess", hess)
+        minimizer_kwargs.setdefault("hessp", hessp)
+        minimizer_kwargs.setdefault("bounds", bounds)
+
+        res = scipy.optimize.shgo(
+            partial(self._objective, params=params, closure=closure),
+            bounds=bounds,
+            minimizer_kwargs=minimizer_kwargs,
+            options=options,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
+        return res.fun
+
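ScipySHGO asserts that bounds exist, so every group needs finite `lb`/`ub` to define the box that scipy.optimize.shgo samples over. A short sketch on a toy problem (names and the import path are illustrative assumptions):

import torch
from torchzero.optim.wrappers.scipy.sgho import ScipySHGO  # assumed import path

x = torch.nn.Parameter(torch.zeros(2))
opt = ScipySHGO([x], lb=-5.0, ub=5.0, n=64, method='l-bfgs-b')

def closure(backward=True):
    loss = (x[0] - 1.0) ** 2 + (x[1] + 2.0) ** 2  # toy objective
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # global sampling over the [-5, 5]^2 box plus local refinement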
torchzero/optim/wrappers/wrapper.py
ADDED
@@ -0,0 +1,121 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+import numpy as np
+import torch
+
+from ...utils import TensorList, tonumpy
+from ...utils.derivatives import (
+    flatten_jacobian,
+    jacobian_and_hessian_mat_wrt,
+    jacobian_wrt,
+)
+
+
+class WrapperBase(torch.optim.Optimizer):
+    def __init__(self, params, defaults):
+        super().__init__(params, defaults)
+
+    @torch.no_grad
+    def _f(self, x: np.ndarray, params: list[torch.Tensor], closure) -> float:
+        # set params to x
+        params = TensorList(params)
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        return float(closure(False))
+
+    @torch.no_grad
+    def _fs(self, x: np.ndarray, params: list[torch.Tensor], closure) -> np.ndarray:
+        # set params to x
+        params = TensorList(params)
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        return tonumpy(closure(False)).reshape(-1)
+
+
+    @torch.no_grad
+    def _f_g(self, x: np.ndarray, params: list[torch.Tensor], closure) -> tuple[float, np.ndarray]:
+        # set params to x
+        params = TensorList(params)
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        # compute value and derivatives
+        with torch.enable_grad():
+            value = closure()
+        g = params.grad.fill_none(reference=params).to_vec()
+        return float(value), g.numpy(force=True)
+
+    @torch.no_grad
+    def _f_g_H(self, x: np.ndarray, params: list[torch.Tensor], closure) -> tuple[float, np.ndarray, np.ndarray]:
+        # set params to x
+        params = TensorList(params)
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        # compute value and derivatives
+        with torch.enable_grad():
+            value = closure(False)
+            g, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
+        return float(value), g.numpy(force=True), H.numpy(force=True)
+
+    @torch.no_grad
+    def _f_g_Hvp(self, x: np.ndarray, v: np.ndarray, params: list[torch.Tensor], closure) -> tuple[float, np.ndarray, np.ndarray]:
+        # set params to x
+        params = TensorList(params)
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        # compute value and derivatives
+        with torch.enable_grad():
+            value = closure(False)
+            grad = torch.autograd.grad(value, params, create_graph=True, allow_unused=True, materialize_grads=True)
+            flat_grad = torch.cat([i.reshape(-1) for i in grad])
+            Hvp = torch.autograd.grad(flat_grad, params, torch.as_tensor(v, device=flat_grad.device, dtype=flat_grad.dtype))[0]
+
+        return float(value), flat_grad.numpy(force=True), Hvp.numpy(force=True)
+
+    def _get_params(self) -> list[torch.Tensor]:
+        return [p for g in self.param_groups for p in g["params"]]
+
+    def _get_per_parameter_lb_ub(self):
+        # get per-parameter lb and ub
+        lb = []
+        ub = []
+        for group in self.param_groups:
+            lb.extend([group["lb"]] * len(group["params"]))
+            ub.extend([group["ub"]] * len(group["params"]))
+
+        return lb, ub
+
+    def _get_bounds(self):
+
+        # get per-parameter lb and ub
+        lb, ub = self._get_per_parameter_lb_ub()
+        if all(i is None for i in lb) and all(i is None for i in ub): return None
+
+        params = self._get_params()
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
+
+        return bounds
+
+    def _get_lb_ub(self, ld:dict | None = None, ud: dict | None = None):
+        if ld is None: ld = {}
+        if ud is None: ud = {}
+
+        # get per-parameter lb and ub
+        lb, ub = self._get_per_parameter_lb_ub()
+
+        params = self._get_params()
+        lb_list = []
+        ub_list = []
+        for p, l, u in zip(params, lb, ub):
+            if l in ld: l = ld[l]
+            if u in ud: l = ud[u]
+            lb_list.extend([l] * p.numel())
+            ub_list.extend([u] * p.numel())
+
+        return lb_list, ub_list
+
+    @abstractmethod
+    def step(self, closure) -> Any: # pyright:ignore[reportIncompatibleMethodOverride] # pylint:disable=signature-differs
+        ...
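WrapperBase does the numpy round trip for all of these wrappers: `_f`, `_f_g`, `_f_g_H` and `_f_g_Hvp` write the incoming vector back into the parameters, evaluate the closure and return plain floats and arrays, while `_get_params` and `_get_bounds` flatten the param groups. A hedged sketch of how a new wrapper could be built on top of it, here around the legacy scipy.optimize.fmin_l_bfgs_b API purely to illustrate the pattern (the class and import path are hypothetical, not part of this release):

from functools import partial

import scipy.optimize
import torch

from torchzero.utils import TensorList
from torchzero.optim.wrappers.wrapper import WrapperBase  # assumed import path


class LBFGSBWrapper(WrapperBase):
    """Hypothetical example wrapper, not part of torchzero."""
    def __init__(self, params, maxiter: int = 100):
        super().__init__(params, dict(lb=None, ub=None))  # lb/ub feed _get_bounds
        self.maxiter = maxiter

    @torch.no_grad
    def step(self, closure):
        params = TensorList(self._get_params())
        x0 = params.to_vec().numpy(force=True)

        # _f_g writes x into the params, evaluates the closure and returns
        # (loss as float, flattened gradient as a numpy array), which is what
        # fmin_l_bfgs_b expects when fprime is None and approx_grad is False
        x, f, _ = scipy.optimize.fmin_l_bfgs_b(
            partial(self._f_g, params=params, closure=closure),
            x0,
            bounds=self._get_bounds(),
            maxiter=self.maxiter,
        )

        params.from_vec_(torch.as_tensor(x, device=params[0].device, dtype=params[0].dtype))
        return f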
torchzero/utils/__init__.py
CHANGED
@@ -1,33 +1,15 @@
 from . import tensorlist as tl
-
-
-
-
-
-
-)
-from .numberlist import NumberList
-from .optimizer import (
-    Init,
-    ListLike,
-    Optimizer,
-    ParamFilter,
-    get_group_vals,
-    get_params,
-    get_state_vals,
-    unpack_states,
-)
-from .params import (
-    Params,
-    _add_defaults_to_param_groups_,
-    _add_updates_grads_to_param_groups_,
-    _copy_param_groups,
-    _make_param_groups,
-)
+
+from .metrics import evaluate_metric
+from .numberlist import NumberList, maybe_numberlist
+from .optimizer import unpack_states
+
+
 from .python_tools import (
     flatten,
     generic_eq,
     generic_ne,
+    generic_is_none,
     reduce_dim,
     safe_dict_update_,
     unpack_dicts,
torchzero/utils/compile.py
CHANGED
@@ -38,11 +38,11 @@ class _MaybeCompiledFunc:
 _optional_compiler = _OptionalCompiler()
 """this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""
 
-def
+def enable_compilation(enable: bool=True):
     """`enable` is False by default. When True, certain functions will be compiled, which may not work on some systems like Windows, but it usually improves performance."""
     _optional_compiler.enable = enable
 
-def
+def allow_compile(fn): return _optional_compiler.enable_compilation(fn)
 
 def benchmark_compile_cuda(fn, n: int, **kwargs):
     # warmup