torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/optim/wrappers/directsearch.py
CHANGED
@@ -33,8 +33,45 @@ class DirectSearch(Optimizer):
     solution.

     Args:
-        params
-
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+
+        rho: Choice of the forcing function.
+
+        sketch_dim: Reduced dimension to generate polling directions in.
+
+        sketch_type: Sketching technique to be used.
+
+        maxevals: Maximum number of calls to f performed by the algorithm.
+
+        poll_type: Type of polling directions generated in the reduced spaces.
+
+        alpha0: Initial value for the stepsize parameter.
+
+        alpha_max: Maximum value for the stepsize parameter.
+
+        alpha_min: Minimum value for the stepsize parameter.
+
+        gamma_inc: Increase factor for the stepsize update.
+
+        gamma_dec: Decrease factor for the stepsize update.
+
+        verbose:
+            Boolean indicating whether information should be displayed during an algorithmic run.
+
+        print_freq:
+            Value indicating how frequently information should be displayed.
+
+        use_stochastic_three_points:
+            Boolean indicating whether the specific stochastic three points method should be used.
+
+        poll_scale_prob: Probability of scaling the polling directions.
+
+        poll_scale_factor: Factor used to scale the polling directions.
+
+        rho_uses_normd:
+            Boolean indicating whether the forcing function should account for the norm of the direction.
+
+
     """
     def __init__(
         self,
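The expanded docstring maps one-to-one onto directsearch's solver options. A hedged construction sketch (parameter names from the docstring above; the closure-based step() convention of the other torchzero wrappers is assumed):

    import torch
    from torchzero.optim.wrappers.directsearch import DirectSearch

    w = torch.nn.Parameter(torch.zeros(10))
    opt = DirectSearch([w], maxevals=5000, alpha0=0.5)
    opt.step(lambda: (w - 2).pow(2).sum())  # derivative-free polling, no backward() needed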
torchzero/optim/wrappers/fcmaes.py
CHANGED
@@ -27,18 +27,25 @@ class FcmaesWrapper(Optimizer):
     Note that this performs full minimization on each step, so only perform one step with this.

     Args:
-        params
-        lb (float):
-        ub (float):
-        optimizer (fcmaes.optimizer.Optimizer | None, optional):
-
-
-
-
-
-
-
-
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lb (float): lower bounds, this can also be specified in param_groups.
+        ub (float): upper bounds, this can also be specified in param_groups.
+        optimizer (fcmaes.optimizer.Optimizer | None, optional):
+            optimizer to use. Default is a sequence of differential evolution and CMA-ES.
+        max_evaluations (int | None, optional):
+            Forced termination of all optimization runs after `max_evaluations` function evaluations.
+            Only used if optimizer is undefined, otherwise this setting is defined in the optimizer. Defaults to 50000.
+        value_limit (float | None, optional): Upper limit for optimized function values to be stored. Defaults to np.inf.
+        num_retries (int | None, optional): Number of optimization retries. Defaults to 1.
+        popsize (int | None, optional):
+            CMA-ES population size used for all CMA-ES runs.
+            Not used for differential evolution.
+            Ignored if parameter optimizer is defined. Defaults to 31.
+        capacity (int | None, optional): capacity of the evaluation store.. Defaults to 500.
+        stop_fitness (float | None, optional):
+            Limit for fitness value. optimization runs terminate if this value is reached. Defaults to -np.inf.
+        statistic_num (int | None, optional):
+            if > 0 stores the progress of the optimization. Defines the size of this store. Defaults to 0.
     """
     def __init__(
         self,

@@ -49,7 +56,7 @@ class FcmaesWrapper(Optimizer):
         max_evaluations: int | None = 50000,
         value_limit: float | None = np.inf,
         num_retries: int | None = 1,
-        workers: int = 1,
+        # workers: int = 1,
         popsize: int | None = 31,
         capacity: int | None = 500,
         stop_fitness: float | None = -np.inf,

@@ -60,6 +67,7 @@ class FcmaesWrapper(Optimizer):
         kwargs = locals().copy()
         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
         self._kwargs = kwargs
+        self._kwargs['workers'] = 1

     def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
         if self.raised: return np.inf
torchzero/optim/wrappers/mads.py
CHANGED
@@ -31,16 +31,15 @@ class MADS(Optimizer):
     solution.

     Args:
-        params
-        lb (float): lower bounds
-        ub (float): upper bounds
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lb (float): lower bounds, this can also be specified in param_groups.
+        ub (float): upper bounds, this can also be specified in param_groups.
         dp (float, optional): Initial poll size as percent of bounds. Defaults to 0.1.
         dm (float, optional): Initial mesh size as percent of bounds. Defaults to 0.01.
-        dp_tol (
-        nitermax (
+        dp_tol (float, optional): Minimum poll size stopping criteria. Defaults to -float('inf').
+        nitermax (float, optional): Maximum objective function evaluations. Defaults to float('inf').
         displog (bool, optional): whether to show log. Defaults to False.
         savelog (bool, optional): whether to save log. Defaults to False.
-
     """
     def __init__(
         self,
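A hedged construction sketch from the documented MADS arguments (bounds are required; dp and dm default to 0.1 and 0.01; the closure-based step() convention is assumed):

    import torch
    from torchzero.optim.wrappers.mads import MADS

    w = torch.nn.Parameter(torch.zeros(3))
    opt = MADS([w], lb=-10.0, ub=10.0)      # initial poll/mesh sizes scale with the bounds
    opt.step(lambda: (w - 1).pow(2).sum())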
torchzero/optim/wrappers/nevergrad.py
CHANGED
@@ -29,6 +29,12 @@ class NevergradWrapper(Optimizer):
             use certain rule for first 50% of the steps, and then switch to another rule.
             This parameter doesn't actually limit the maximum number of steps!
             But it doesn't have to be exact. Defaults to None.
+        lb (float | None, optional):
+            lower bounds, this can also be specified in param_groups. Bounds are optional, however
+            some nevergrad algorithms will raise an exception of bounds are not specified.
+        ub (float, optional):
+            upper bounds, this can also be specified in param_groups. Bounds are optional, however
+            some nevergrad algorithms will raise an exception of bounds are not specified.
         mutable_sigma (bool, optional):
             nevergrad parameter, sets whether the mutation standard deviation must mutate as well
             (for mutation based algorithms). Defaults to False.

@@ -44,11 +50,20 @@ class NevergradWrapper(Optimizer):
         params,
         opt_cls:"type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
         budget: int | None = None,
-        mutable_sigma = False,
         lb: float | None = None,
         ub: float | None = None,
+        mutable_sigma = False,
         use_init = True,
     ):
+        """_summary_
+
+        Args:
+            params (_type_): _description_
+            opt_cls (type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]): _description_
+            budget (int | None, optional): _description_. Defaults to None.
+            mutable_sigma (bool, optional): _description_. Defaults to False.
+            use_init (bool, optional): _description_. Defaults to True.
+        """
         defaults = dict(lb=lb, ub=ub, use_init=use_init, mutable_sigma=mutable_sigma)
         super().__init__(params, defaults)
         self.opt_cls = opt_cls
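A sketch of supplying the newly documented bounds (argument order follows the updated signature; ng.optimizers.NGOpt is just one possible opt_cls):

    import nevergrad as ng
    import torch
    from torchzero.optim.wrappers.nevergrad import NevergradWrapper

    w = torch.nn.Parameter(torch.zeros(8))
    # bounds are optional, but some nevergrad algorithms raise if they are missing
    opt = NevergradWrapper([w], ng.optimizers.NGOpt, budget=500, lb=-1.0, ub=1.0)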
torchzero/optim/wrappers/optuna.py
CHANGED
@@ -23,7 +23,7 @@ class OptunaSampler(Optimizer):
     Note - optuna is surprisingly scalable to large number of parameters (up to 10,000), despite literally requiring a for-loop because it only supports scalars. Default TPESampler is good for BBO. Maybe not for NNs...

     Args:
-        params
+        params: iterable of parameters to optimize or dicts defining parameter groups.
         lb (float): lower bounds.
         ub (float): upper bounds.
         sampler (optuna.samplers.BaseSampler | type[optuna.samplers.BaseSampler] | None, optional): sampler. Defaults to None.
torchzero/optim/wrappers/scipy.py
CHANGED
@@ -139,9 +139,11 @@ class ScipyMinimize(Optimizer):

         # make bounds
         lb, ub = self.group_vals('lb', 'ub', cls=list)
-        bounds =
-
-        bounds
+        bounds = None
+        if any(b is not None for b in lb) or any(b is not None for b in ub):
+            bounds = []
+            for p, l, u in zip(params, lb, ub):
+                bounds.extend([(l, u)] * p.numel())

         if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
             x0 = x0.astype(np.float64) # those methods error without this
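The rewritten block only materializes bounds when some group sets lb or ub, and expands each group's scalar bounds into one (lb, ub) pair per tensor element, the per-variable format scipy.optimize.minimize expects. The same loop in isolation:

    import torch

    params = [torch.zeros(2), torch.zeros(3)]
    lb, ub = [-1.0, None], [1.0, None]      # one entry per param group

    bounds = None
    if any(b is not None for b in lb) or any(b is not None for b in ub):
        bounds = []
        for p, l, u in zip(params, lb, ub):
            bounds.extend([(l, u)] * p.numel())

    # bounds == [(-1.0, 1.0), (-1.0, 1.0), (None, None), (None, None), (None, None)]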
torchzero/utils/__init__.py
CHANGED
@@ -18,6 +18,6 @@ from .params import (
     _copy_param_groups,
     _make_param_groups,
 )
-from .python_tools import flatten, generic_eq, reduce_dim, unpack_dicts
-from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like
+from .python_tools import flatten, generic_eq, generic_ne, reduce_dim, unpack_dicts
+from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like, generic_finfo_eps
 from .torch_tools import tofloat, tolist, tonumpy, totensor, vec_to_tensors, vec_to_tensors_, set_storage_
torchzero/utils/derivatives.py
CHANGED
@@ -158,7 +158,7 @@ def hessian_mat(
     method="func",
     vectorize=False,
     outer_jacobian_strategy="reverse-mode",
-):
+) -> torch.Tensor:
     """
     returns hessian matrix for parameters (as if they were flattened and concatenated into a vector).


@@ -190,7 +190,7 @@ def hessian_mat(
         return loss

     if method == 'func':
-        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph))
+        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph)) # pyright:ignore[reportReturnType]

     if method == 'autograd.functional':
         return torch.autograd.functional.hessian(

@@ -199,7 +199,7 @@ def hessian_mat(
             create_graph=create_graph,
             vectorize=vectorize,
             outer_jacobian_strategy=outer_jacobian_strategy,
-        )
+        ) # pyright:ignore[reportReturnType]
     raise ValueError(method)

 def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
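For context on the method='func' branch annotated above: it treats the loss as a function of the flattened parameter vector and differentiates twice with torch.func.hessian. The pattern, reduced to a self-contained sketch:

    import torch

    params = [torch.randn(2), torch.randn(3)]

    def func(x):                 # loss as a function of the flat parameter vector
        return (x ** 2).sum() + x.prod()

    flat = torch.cat([p.view(-1) for p in params]).detach()
    H = torch.func.hessian(func)(flat)   # (5, 5) Hessian of the flattened params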
torchzero/utils/linalg/__init__.py
CHANGED
@@ -2,4 +2,4 @@ from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix
 from .orthogonalize import gram_schmidt
 from .qr import qr_householder
 from .svd import randomized_svd
-from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
+from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve, steihaug_toint_cg
torchzero/utils/linalg/solve.py
CHANGED
@@ -1,12 +1,41 @@
+# pyright: reportArgumentType=false
 from collections.abc import Callable
-from typing import overload
+from typing import Any, overload
+
 import torch

-from .. import
+from .. import (
+    TensorList,
+    generic_eq,
+    generic_finfo_eps,
+    generic_numel,
+    generic_randn_like,
+    generic_vector_norm,
+    generic_zeros_like,
+)
+
+
+def _make_A_mm_reg(A_mm: Callable | torch.Tensor, reg):
+    if callable(A_mm):
+        def A_mm_reg(x): # A_mm with regularization
+            Ax = A_mm(x)
+            if not generic_eq(reg, 0): Ax += x*reg
+            return Ax
+        return A_mm_reg
+
+    if not isinstance(A_mm, torch.Tensor): raise TypeError(type(A_mm))
+
+    def Ax_reg(x): # A_mm with regularization
+        if A_mm.ndim == 1: Ax = A_mm * x
+        else: Ax = A_mm @ x
+        if reg != 0: Ax += x*reg
+        return Ax
+    return Ax_reg
+

 @overload
 def cg(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
     b: torch.Tensor,
     x0_: torch.Tensor | None = None,
     tol: float | None = 1e-4,
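The new _make_A_mm_reg helper is what lets every solver below accept either a matvec callable or a tensor: a 1-D tensor acts as a diagonal operator, a 2-D tensor as a dense matmul, and a nonzero reg shifts the operator to A + reg*I. Illustrating the tensor cases:

    import torch

    diag = torch.tensor([1.0, 2.0, 4.0])
    x = torch.ones(3)

    # a 1-D tensor behaves as a diagonal matrix
    assert torch.equal(diag * x, torch.diag(diag) @ x)

    # reg != 0 adds reg*x, so the solvers effectively see (A + reg*I) @ x
    reg = 0.1
    Ax = diag * x + reg * x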
@@ -24,17 +53,17 @@ def cg(
 ) -> TensorList: ...

 def cg(
-    A_mm: Callable,
+    A_mm: Callable | torch.Tensor,
     b: torch.Tensor | TensorList,
     x0_: torch.Tensor | TensorList | None = None,
     tol: float | None = 1e-4,
     maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
 ):
-
-
-
-
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = generic_finfo_eps(b)
+
+    if tol is None: tol = eps

     if maxiter is None: maxiter = generic_numel(b)
     if x0_ is None: x0_ = generic_zeros_like(b)

@@ -44,9 +73,10 @@ def cg(
     p = residual.clone() # search direction
     r_norm = generic_vector_norm(residual)
     init_norm = r_norm
-    if
+    if r_norm < tol: return x
     k = 0

+
     while True:
         Ap = A_mm_reg(p)
         step_size = (r_norm**2) / p.dot(Ap)

@@ -55,7 +85,7 @@ def cg(
         new_r_norm = generic_vector_norm(residual)

         k += 1
-        if
+        if new_r_norm <= tol * init_norm: return x
         if k >= maxiter: return x

         beta = (new_r_norm**2) / (r_norm**2)
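With the widened signature, cg can now be called directly on a dense SPD matrix rather than a closure; a minimal sketch (import path from this file's location):

    import torch
    from torchzero.utils.linalg.solve import cg

    A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
    b = torch.tensor([1.0, 2.0])

    x = cg(A, b, tol=1e-8)
    assert torch.allclose(A @ x, b, atol=1e-4)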
@@ -131,6 +161,8 @@ def nystrom_pcg(
         generator=generator,
     )
     lambd += reg
+    eps = torch.finfo(b.dtype).eps ** 2
+    if tol is None: tol = eps

     def A_mm_reg(x): # A_mm with regularization
         Ax = A_mm(x)

@@ -150,7 +182,7 @@ def nystrom_pcg(
     p = z.clone() # search direction

     init_norm = torch.linalg.vector_norm(residual) # pylint:disable=not-callable
-    if
+    if init_norm < tol: return x
     k = 0
     while True:
         Ap = A_mm_reg(p)

@@ -160,10 +192,217 @@ def nystrom_pcg(
         residual -= step_size * Ap

         k += 1
-        if
+        if torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
         if k >= maxiter: return x

         z = P_inv @ residual
         beta = residual.dot(z) / rz
         p = z + p*beta

+
+def _safe_clip(x: torch.Tensor):
+    """makes sure scalar tensor x is not smaller than epsilon"""
+    assert x.numel() == 1, x.shape
+    eps = torch.finfo(x.dtype).eps
+    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+    return x
+
+def _trust_tau(x,d,trust_region):
+    xx = x.dot(x)
+    xd = x.dot(d)
+    dd = _safe_clip(d.dot(d))
+
+    rad = (xd**2 - dd * (xx - trust_region**2)).clip(min=0).sqrt()
+    tau = (-xd + rad) / dd
+
+    return x + tau * d
+
+
+@overload
+def steihaug_toint_cg(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    trust_region: float,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+) -> torch.Tensor: ...
+@overload
+def steihaug_toint_cg(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    trust_region: float,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+) -> TensorList: ...
+def steihaug_toint_cg(
+    A_mm: Callable | torch.Tensor,
+    b: torch.Tensor | TensorList,
+    trust_region: float,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+):
+    """
+    Solution is bounded to have L2 norm no larger than :code:`trust_region`. If solution exceeds :code:`trust_region`, CG is terminated early, so it is also faster.
+    """
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+
+    x = x0
+    if x is None: x = generic_zeros_like(b)
+    r = b
+    d = r.clone()
+
+    eps = generic_finfo_eps(b)**2
+    if tol is None: tol = eps
+
+    if generic_vector_norm(r) < tol:
+        return x
+
+    if maxiter is None:
+        maxiter = generic_numel(b)
+
+    for _ in range(maxiter):
+        Ad = A_mm_reg(d)
+
+        d_Ad = d.dot(Ad)
+        if d_Ad <= eps:
+            return _trust_tau(x, d, trust_region)
+
+        alpha = r.dot(r) / d_Ad
+        p_next = x + alpha * d
+
+        # check if the step exceeds the trust-region boundary
+        if generic_vector_norm(p_next) >= trust_region:
+            return _trust_tau(x, d, trust_region)
+
+        # update step, residual and direction
+        x = p_next
+        r_next = r - alpha * Ad
+
+        if generic_vector_norm(r_next) < tol:
+            return x
+
+        beta = r_next.dot(r_next) / r.dot(r)
+        d = r_next + beta * d
+        r = r_next
+
+    return x
+
+
+
+# Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
+@overload
+def minres(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+) -> torch.Tensor: ...
+@overload
+def minres(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+) -> TensorList: ...
+def minres(
+    A_mm,
+    b,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+):
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = generic_finfo_eps(b)
+    if tol is None: tol = eps**2
+
+    if maxiter is None: maxiter = generic_numel(b)
+    if x0 is None:
+        R = b
+        x0 = generic_zeros_like(b)
+    else:
+        R = b - A_mm_reg(x0)
+
+    X: Any = x0
+    beta = b_norm = generic_vector_norm(b)
+    if b_norm < eps**2:
+        return generic_zeros_like(b)
+
+
+    V = b / beta
+    V_prev = generic_zeros_like(b)
+    D = generic_zeros_like(b)
+    D_prev = generic_zeros_like(b)
+
+    c = -1
+    phi = tau = beta
+    s = delta1 = e = 0
+
+
+    for _ in range(maxiter):
+
+        P = A_mm_reg(V)
+        alpha = V.dot(P)
+        P -= beta*V_prev
+        P -= alpha*V
+        beta = generic_vector_norm(P)
+
+        delta2 = c*delta1 + s*alpha
+        gamma1 = s*delta1 - c*alpha
+        e_next = s*beta
+        delta1 = -c*beta
+
+        cgamma1 = c*gamma1
+        if trust_region is not None and cgamma1 >= 0:
+            if npc_terminate: return _trust_tau(X, R, trust_region)
+            return _trust_tau(X, D, trust_region)
+
+        if npc_terminate and cgamma1 >= 0:
+            return R
+
+        gamma2 = (gamma1**2 + beta**2)**(1/2)
+
+        if abs(gamma2) <= eps: # singular system
+            # c=0; s=1; tau=0
+            if trust_region is None: return X
+            return _trust_tau(X, D, trust_region)
+
+        c = gamma1 / gamma2
+        s = beta/gamma2
+        tau = c*phi
+        phi = s*phi
+
+        D_prev = D
+        D = (V - delta2*D - e*D_prev) / gamma2
+        e = e_next
+        X = X + tau*D
+
+        if trust_region is not None:
+            if generic_vector_norm(X) > trust_region:
+                return _trust_tau(X, D, trust_region)
+
+        if (abs(beta) < eps) or (phi / b_norm <= tol):
+            # R = zeros(R)
+            return X
+
+        V_prev = V
+        V = P/beta
+        R = s**2*R - phi*c*V
+
+    return X
torchzero/utils/numberlist.py
CHANGED
@@ -129,4 +129,6 @@ class NumberList(list[int | float | Any]):
         return self.__class__(fn(i, *args, **kwargs) for i in self)

     def clamp(self, min=None, max=None):
+        return self.zipmap_args(_clamp, min, max)
+    def clip(self, min=None, max=None):
         return self.zipmap_args(_clamp, min, max)
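The addition makes clip an alias of clamp (both dispatch to _clamp through zipmap_args), mirroring torch.Tensor, where clip likewise aliases clamp. A quick sketch:

    from torchzero.utils.numberlist import NumberList

    nl = NumberList([0.5, 2.0, -3.0])
    assert nl.clip(min=-1, max=1) == nl.clamp(min=-1, max=1)   # [0.5, 1, -1] either way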
torchzero/utils/python_tools.py
CHANGED
|
@@ -31,6 +31,16 @@ def generic_eq(x: int | float | Iterable[int | float], y: int | float | Iterable
|
|
|
31
31
|
return all(i==y for i in x)
|
|
32
32
|
return all(i==j for i,j in zip(x,y))
|
|
33
33
|
|
|
34
|
+
def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
|
|
35
|
+
"""generic not equals function that supports scalars and lists of numbers. Faster than not generic_eq"""
|
|
36
|
+
if isinstance(x, (int,float)):
|
|
37
|
+
if isinstance(y, (int,float)): return x!=y
|
|
38
|
+
return any(i!=x for i in y)
|
|
39
|
+
if isinstance(y, (int,float)):
|
|
40
|
+
return any(i!=y for i in x)
|
|
41
|
+
return any(i!=j for i,j in zip(x,y))
|
|
42
|
+
|
|
43
|
+
|
|
34
44
|
def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
|
|
35
45
|
"""If `other` is list/tuple, applies `fn` to self zipped with `other`.
|
|
36
46
|
Otherwise applies `fn` to this sequence and `other`.
|