torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -1,572 +0,0 @@
|
|
|
1
|
-
from collections import abc
|
|
2
|
-
from collections.abc import Callable
|
|
3
|
-
from functools import partial
|
|
4
|
-
from typing import Any, Literal
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
import torch
|
|
8
|
-
|
|
9
|
-
import scipy.optimize
|
|
10
|
-
|
|
11
|
-
from ...utils import Optimizer, TensorList
|
|
12
|
-
from ...utils.derivatives import (
|
|
13
|
-
flatten_jacobian,
|
|
14
|
-
jacobian_and_hessian_mat_wrt,
|
|
15
|
-
jacobian_wrt,
|
|
16
|
-
)
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def _ensure_float(x) -> float:
|
|
20
|
-
if isinstance(x, torch.Tensor): return x.detach().cpu().item()
|
|
21
|
-
if isinstance(x, np.ndarray): return float(x.item())
|
|
22
|
-
return float(x)
|
|
23
|
-
|
|
24
|
-
def _ensure_numpy(x):
|
|
25
|
-
if isinstance(x, torch.Tensor): return x.detach().cpu()
|
|
26
|
-
if isinstance(x, np.ndarray): return x
|
|
27
|
-
return np.array(x)
|
|
28
|
-
|
|
29
|
-
# Type of the closures accepted by the optimizers below. They call it as
# ``closure()`` when gradients are needed and ``closure(False)`` otherwise.
# NOTE(review): presumably the bool is a ``backward`` flag — confirm against
# torchzero's closure convention.
Closure = Callable[[bool], Any]
|
|
30
|
-
|
|
31
|
-
class ScipyMinimize(Optimizer):
    """Use ``scipy.optimize.minimize`` as pytorch optimizer.

    Note that this performs full minimization on each step, so usually you would
    want to perform a single step. Performing multiple steps will refine the solution.

    Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
    for a detailed description of args.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        method (str | None, optional): type of solver.
            If None, scipy will select one of BFGS, L-BFGS-B, SLSQP,
            depending on whether or not the problem has constraints or bounds.
            Defaults to None.
        lb (float | None, optional): lower bound applied to every element of the parameters.
            Stored as a per-group setting; None means no lower bound. Defaults to None.
        ub (float | None, optional): upper bound applied to every element of the parameters.
            Stored as a per-group setting; None means no upper bound. Defaults to None.
        constraints (optional): constraints definition passed to scipy. Defaults to ().
        tol (float | None, optional): Tolerance for termination. Defaults to None.
        callback (Callable | None, optional): A callable called after each iteration. Defaults to None.
        options (dict | None, optional): A dictionary of solver options. Defaults to None.
        jac (str, optional): Method for computing the gradient vector.
            Only for CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
            In addition to scipy options, this supports 'autograd', which uses pytorch autograd.
            'autograd' is ignored for methods that don't require a gradient. Defaults to 'autograd'.
        hess (str | scipy.optimize.HessianUpdateStrategy, optional):
            Method for computing the Hessian matrix.
            Only for Newton-CG, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
            'autograd' computes the full Hessian matrix with pytorch autograd. It is used only for
            newton-cg, dogleg, trust-ncg, trust-krylov and trust-exact. For every other method
            (including trust-constr) no hessian is passed to scipy. Defaults to 'autograd'.
    """
    def __init__(
        self,
        params,
        method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
                        'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
                        'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
                        'trust-krylov'] | str | None = None,
        lb = None,
        ub = None,
        constraints = (),
        tol: float | None = None,
        callback = None,
        options = None,
        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
        hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
    ):
        defaults = dict(lb=lb, ub=ub)
        super().__init__(params, defaults)
        self.method = method
        self.constraints = constraints
        self.tol = tol
        self.callback = callback
        self.options = options

        self.jac = jac
        self.hess = hess

        self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
            'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
            'trust-ncg', 'trust-krylov', 'trust-exact', 'trust-constr',
        ])
        self.use_hess_autograd = isinstance(hess, str) and hess.lower() == 'autograd' and method is not None and method.lower() in [
            'newton-cg', 'dogleg', 'trust-ncg', 'trust-krylov', 'trust-exact'
        ]

        # jac in scipy is '2-point', '3-point', 'cs', True or None.
        # Compare case-insensitively, consistent with use_jac_autograd above.
        if self.jac.lower() == 'autograd':
            if self.use_jac_autograd: self.jac = True
            else: self.jac = None

    def _needs_float64(self) -> bool:
        """Whether scipy will run TNC or SLSQP, which error without float64 arrays."""
        if self.method is None:
            # with method=None scipy picks SLSQP whenever constraints are given
            return bool(self.constraints)
        return self.method.lower() in ('tnc', 'slsqp')

    def _hess(self, x: np.ndarray, params: TensorList, closure):
        """Load ``x`` into ``params`` and return the autograd Hessian of ``closure(False)`` as a numpy array."""
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        with torch.enable_grad():
            value = closure(False)
            _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
        return H.numpy(force=True)

    def _objective(self, x: np.ndarray, params: TensorList, closure):
        """Load ``x`` into ``params`` and evaluate the closure.

        Returns ``(value, grad)`` when autograd gradients are used, otherwise just ``value``.
        """
        # set params to x
        params.from_vec_(torch.from_numpy(x).to(params[0], copy=False))

        # return value and maybe gradients
        if self.use_jac_autograd:
            with torch.enable_grad(): value = _ensure_float(closure())
            grad = params.ensure_grad_().grad.to_vec().numpy(force=True)
            # slsqp/tnc require float64; self.method may be None here (scipy chooses),
            # so it must not be dereferenced directly
            if self._needs_float64(): grad = grad.astype(np.float64)
            return value, grad
        return _ensure_float(closure(False))

    @torch.no_grad
    def step(self, closure: Closure):# pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
        """Run a full ``scipy.optimize.minimize`` from the current parameters.

        The parameters are set to the solution. Returns ``res.fun``, the final
        objective value reported by scipy.
        """
        params = self.get_params()

        # determine hess argument
        if isinstance(self.hess, str) and self.hess.lower() == 'autograd':
            if self.use_hess_autograd: hess = partial(self._hess, params = params, closure = closure)
            else: hess = None
        else: hess = self.hess

        x0 = params.to_vec().numpy(force=True)

        # make bounds: one (lb, ub) pair per flattened parameter element
        lb, ub = self.group_vals('lb', 'ub', cls=list)
        bounds = None
        if any(b is not None for b in lb) or any(b is not None for b in ub):
            bounds = []
            for p, l, u in zip(params, lb, ub):
                bounds.extend([(l, u)] * p.numel())

        if self._needs_float64():
            x0 = x0.astype(np.float64) # those methods error without this

        res = scipy.optimize.minimize(
            partial(self._objective, params = params, closure = closure),
            x0 = x0,
            method=self.method,
            bounds=bounds,
            constraints=self.constraints,
            tol=self.tol,
            callback=self.callback,
            options=self.options,
            jac = self.jac,
            hess = hess,
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return res.fun
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
class ScipyRootOptimization(Optimizer):
    """Optimization by finding a root of the gradient with ``scipy.optimize.root``, mainly for experimenting!

    Each step runs a full ``scipy.optimize.root`` solve starting from the current
    parameters. The function whose root is sought maps the flattened parameter
    vector to the gradient of the closure's value.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        method (str, optional): root-finding method passed to ``scipy.optimize.root``. Defaults to 'hybr'.
        tol (float | None, optional): tolerance for termination, passed to scipy. Defaults to None.
        callback (Callable | None, optional): callback passed to scipy. Defaults to None.
        options (dict | None, optional): solver options passed to scipy. Defaults to None.
        jac (str, optional): how the Jacobian of the gradient (i.e. the Hessian) is obtained.
            'autograd' computes it with pytorch autograd. Any other value lets scipy
            estimate it numerically. Ignored by methods that don't use a Jacobian
            ('broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden',
            'excitingmixing', 'krylov', 'df-sane'). Defaults to 'autograd'.
    """
    def __init__(
        self,
        params,
        method: Literal[
            "hybr",
            "lm",
            "broyden1",
            "broyden2",
            "anderson",
            "linearmixing",
            "diagbroyden",
            "excitingmixing",
            "krylov",
            "df-sane",
        ] = 'hybr',
        tol: float | None = None,
        callback = None,
        options = None,
        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
    ):
        super().__init__(params, {})
        self.method = method
        self.tol = tol
        self.callback = callback
        self.options = options

        # scipy.optimize.root takes a bool (or callable) for ``jac``.
        # True means ``_objective`` returns (gradient, hessian); False means scipy estimates
        # the jacobian numerically. A non-'autograd' string must become False: passed
        # through as-is it is truthy and would silently select the autograd path.
        self.jac = jac == 'autograd'

        # those don't require jacobian
        if self.method.lower() in ('broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden', 'excitingmixing', 'krylov', 'df-sane'):
            self.jac = None

    def _objective(self, x: np.ndarray, params: TensorList, closure):
        """Load ``x`` into ``params`` and return the gradient, plus the Hessian when ``self.jac`` is set."""
        # set params to x
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))

        # return gradients and hessian
        if self.jac:
            with torch.enable_grad():
                self.value = closure(False)
                if not isinstance(self.value, torch.Tensor):
                    raise TypeError(f"Autograd jacobian requires closure to return torch.Tensor, got {type(self.value)}")
                g, H = jacobian_and_hessian_mat_wrt([self.value], wrt=params)
            return g.detach().cpu().numpy(), H.detach().cpu().numpy()

        # return the gradients
        with torch.enable_grad(): self.value = closure()
        jac = params.ensure_grad_().grad.to_vec()
        return jac.detach().cpu().numpy()

    @torch.no_grad
    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
        """Run a full root-finding solve and set the parameters to the solution.

        Returns ``res.fun`` from scipy. This is the value of the root-finding
        function at the solution (the gradient vector as a numpy array), not the loss.
        """
        params = self.get_params()

        x0 = params.to_vec().detach().cpu().numpy()

        res = scipy.optimize.root(
            partial(self._objective, params = params, closure = closure),
            x0 = x0,
            method=self.method,
            tol=self.tol,
            callback=self.callback,
            options=self.options,
            jac = self.jac,
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return res.fun
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
class ScipyLeastSquaresOptimization(Optimizer):
    """Optimization via ``scipy.optimize.least_squares`` on gradients, mainly for experimenting!

    The residual vector handed to scipy is the gradient of the closure's value
    with respect to the flattened parameters. Each step therefore drives the
    gradient towards zero, starting from the current parameters. With
    ``jac='autograd'`` the Jacobian of the residuals (the Hessian of the
    objective) is computed with pytorch autograd.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        method (str, optional): least_squares algorithm. Defaults to 'trf'.
        jac (str | Callable, optional): 'autograd' computes the Jacobian (Hessian) with
            pytorch autograd. Any other value is passed to scipy unchanged. Defaults to 'autograd'.
        bounds (optional): bounds on the flattened parameter vector. Defaults to (-np.inf, np.inf).

        other args:
            forwarded to ``scipy.optimize.least_squares``, refer to
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.least_squares.html
    """
    def __init__(
        self,
        params,
        method='trf',
        jac='autograd',
        bounds=(-np.inf, np.inf),
        ftol=1e-8, xtol=1e-8, gtol=1e-8, x_scale=1.0, loss='linear',
        f_scale=1.0, diff_step=None, tr_solver=None, tr_options=None,
        jac_sparsity=None, max_nfev=None, verbose=0
    ):
        super().__init__(params, {})
        # explicit dict rather than locals() introspection, which depended on
        # implicit names such as ``__class__`` being present
        self._kwargs = dict(
            method=method,
            bounds=bounds,
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
            x_scale=x_scale,
            loss=loss,
            f_scale=f_scale,
            diff_step=diff_step,
            tr_solver=tr_solver,
            # scipy documents tr_options as a dict (default {}); never hand it None
            tr_options={} if tr_options is None else tr_options,
            jac_sparsity=jac_sparsity,
            max_nfev=max_nfev,
            verbose=verbose,
        )

        self.jac = jac

    def _objective(self, x: np.ndarray, params: TensorList, closure):
        """Load ``x`` into ``params`` and return the gradient (the residual vector) as numpy."""
        # set params to x
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))

        # return the gradients
        with torch.enable_grad(): self.value = closure()
        jac = params.ensure_grad_().grad.to_vec()
        return jac.numpy(force=True)

    def _hess(self, x: np.ndarray, params: TensorList, closure):
        """Load ``x`` into ``params`` and return the autograd Hessian (the residual Jacobian) as numpy."""
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        with torch.enable_grad():
            value = closure(False)
            _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
        return H.numpy(force=True)

    @torch.no_grad
    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
        """Run a full least_squares solve and set the parameters to the solution.

        Returns ``res.fun``, the residuals (gradient vector) at the solution.
        """
        params = self.get_params()

        x0 = params.to_vec().detach().cpu().numpy()

        if self.jac == 'autograd': jac = partial(self._hess, params = params, closure = closure)
        else: jac = self.jac

        res = scipy.optimize.least_squares(
            partial(self._objective, params = params, closure = closure),
            x0 = x0,
            jac=jac, # type:ignore
            **self._kwargs
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return res.fun
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
class ScipyDE(Optimizer):
    """Use ``scipy.optimize.differential_evolution`` as pytorch optimizer.

    Note that this performs full minimization on each step, so usually you would
    want to perform a single step. This also requires bounds to be specified.
    The current parameters are used as the initial guess ``x0``, after being
    clipped into ``[lb, ub]``.

    Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
    for all other args.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lb (float): lower bound for every element of the parameters (per-group setting).
        ub (float): upper bound for every element of the parameters (per-group setting).
            DE requires finite bounds.

        other args:
            refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
            Note that ``polish`` defaults to False here.
    """
    def __init__(
        self,
        params,
        lb: float,
        ub: float,
        strategy: Literal['best1bin', 'best1exp', 'rand1bin', 'rand1exp', 'rand2bin', 'rand2exp',
            'randtobest1bin', 'randtobest1exp', 'currenttobest1bin', 'currenttobest1exp',
            'best2exp', 'best2bin'] = 'best1bin',
        maxiter: int = 1000,
        popsize: int = 15,
        tol: float = 0.01,
        mutation = (0.5, 1),
        recombination: float = 0.7,
        seed = None,
        callback = None,
        disp: bool = False,
        polish: bool = False,
        init: str = 'latinhypercube',
        atol: int = 0,
        updating: str = 'immediate',
        workers: int = 1,
        constraints = (),
        *,
        integrality = None,

    ):
        super().__init__(params, lb=lb, ub=ub)

        # keyword arguments forwarded to scipy on every step
        self._kwargs = dict(
            strategy=strategy,
            maxiter=maxiter,
            popsize=popsize,
            tol=tol,
            mutation=mutation,
            recombination=recombination,
            seed=seed,
            callback=callback,
            disp=disp,
            polish=polish,
            init=init,
            atol=atol,
            updating=updating,
            workers=workers,
            constraints=constraints,
            integrality=integrality,
        )

    def _objective(self, x: np.ndarray, params: TensorList, closure):
        """Load ``x`` into ``params`` and return ``closure(False)`` as a float."""
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return _ensure_float(closure(False))

    @torch.no_grad
    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
        """Run a full differential evolution search and return the best objective value."""
        params = self.get_params()

        # one (lb, ub) pair per flattened parameter element
        lb, ub = self.group_vals('lb', 'ub', cls=list)
        bounds = []
        for p, l, u in zip(params, lb, ub):
            bounds.extend([(l, u)] * p.numel())

        # differential_evolution raises a ValueError when x0 lies outside the
        # bounds, so move the starting point inside them
        lower, upper = np.asarray(bounds, dtype=np.float64).reshape(-1, 2).T
        x0 = np.clip(params.to_vec().detach().cpu().numpy(), lower, upper)

        res = scipy.optimize.differential_evolution(
            partial(self._objective, params = params, closure = closure),
            x0 = x0,
            bounds=bounds,
            **self._kwargs
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return res.fun
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
class ScipyDualAnnealing(Optimizer):
    """Use ``scipy.optimize.dual_annealing`` as a pytorch optimizer.

    Every call to :meth:`step` runs a complete dual annealing search over the box
    ``[lb, ub]`` (applied to each parameter element), starting from the current
    parameter values. It then writes the best point found back into the
    parameters. Usually a single step is enough.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lb (float): lower bound for every parameter element (per-group setting).
        ub (float): upper bound for every parameter element (per-group setting).

        other args:
            forwarded unchanged to ``scipy.optimize.dual_annealing``, refer to
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.dual_annealing.html
    """
    def __init__(
        self,
        params,
        lb: float,
        ub: float,
        maxiter=1000,
        minimizer_kwargs=None,
        initial_temp=5230.0,
        restart_temp_ratio=2.0e-5,
        visit=2.62,
        accept=-5.0,
        maxfun=1e7,
        rng=None,
        no_local_search=False,
    ):
        super().__init__(params, lb=lb, ub=ub)

        # keyword arguments forwarded verbatim to scipy on every step
        self._kwargs = {
            'maxiter': maxiter,
            'minimizer_kwargs': minimizer_kwargs,
            'initial_temp': initial_temp,
            'restart_temp_ratio': restart_temp_ratio,
            'visit': visit,
            'accept': accept,
            'maxfun': maxfun,
            'rng': rng,
            'no_local_search': no_local_search,
        }

    def _objective(self, x: np.ndarray, params: TensorList, closure):
        """Copy ``x`` into the parameters and return ``closure(False)`` as a float."""
        ref = params[0]
        candidate = torch.from_numpy(x).to(device=ref.device, dtype=ref.dtype, copy=False)
        params.from_vec_(candidate)
        return _ensure_float(closure(False))

    @torch.no_grad
    def step(self, closure: Closure):
        """Run one full dual annealing search and return the best objective value found."""
        params = self.get_params()
        start = params.to_vec().detach().cpu().numpy()

        lows, highs = self.group_vals('lb', 'ub', cls=list)
        box = [
            (low, high)
            for p, low, high in zip(params, lows, highs)
            for _ in range(p.numel())
        ]

        objective = partial(self._objective, params=params, closure=closure)
        result = scipy.optimize.dual_annealing(objective, x0=start, bounds=box, **self._kwargs)

        ref = params[0]
        best = torch.from_numpy(result.x).to(device=ref.device, dtype=ref.dtype, copy=False)
        params.from_vec_(best)
        return result.fun
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
class ScipySHGO(Optimizer):
    """Use ``scipy.optimize.shgo`` (simplicial homology global optimization) as a pytorch optimizer.

    Each :meth:`step` performs a full global search inside ``[lb, ub]`` for every
    parameter element and then writes the best point into the parameters. The
    current parameter values are not used as a starting point.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lb (float): lower bound for every parameter element (per-group setting).
        ub (float): upper bound for every parameter element (per-group setting).

        other args:
            forwarded unchanged to ``scipy.optimize.shgo``, refer to
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.shgo.html
    """
    def __init__(
        self,
        params,
        lb: float,
        ub: float,
        constraints = None,
        n: int = 100,
        iters: int = 1,
        callback = None,
        minimizer_kwargs = None,
        options = None,
        sampling_method: str = 'simplicial',
    ):
        super().__init__(params, lb=lb, ub=ub)

        # keyword arguments forwarded verbatim to scipy on every step
        self._kwargs = {
            'constraints': constraints,
            'n': n,
            'iters': iters,
            'callback': callback,
            'minimizer_kwargs': minimizer_kwargs,
            'options': options,
            'sampling_method': sampling_method,
        }

    def _objective(self, x: np.ndarray, params: TensorList, closure):
        """Copy ``x`` into the parameters and return ``closure(False)`` as a float."""
        ref = params[0]
        candidate = torch.from_numpy(x).to(device=ref.device, dtype=ref.dtype, copy=False)
        params.from_vec_(candidate)
        return _ensure_float(closure(False))

    @torch.no_grad
    def step(self, closure: Closure):
        """Run one full SHGO search and return the best objective value found."""
        params = self.get_params()

        lows, highs = self.group_vals('lb', 'ub', cls=list)
        box = [
            (low, high)
            for p, low, high in zip(params, lows, highs)
            for _ in range(p.numel())
        ]

        objective = partial(self._objective, params=params, closure=closure)
        result = scipy.optimize.shgo(objective, bounds=box, **self._kwargs)

        ref = params[0]
        best = torch.from_numpy(result.x).to(device=ref.device, dtype=ref.dtype, copy=False)
        params.from_vec_(best)
        return result.fun
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
class ScipyDIRECT(Optimizer):
    """Use ``scipy.optimize.direct`` (DIRECT global search) as a pytorch optimizer.

    Each :meth:`step` searches the box ``[lb, ub]`` for every parameter element
    and then sets the parameters to the best point found.

    If the closure raises during the search, the exception is stored. The
    objective then returns ``inf`` for the rest of the search, and the stored
    exception is re-raised from :meth:`step` after scipy returns.
    NOTE(review): presumably done because exceptions raised inside the objective
    do not propagate cleanly through scipy's DIRECT implementation — confirm.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lb (float): lower bound for every parameter element (per-group setting).
        ub (float): upper bound for every parameter element (per-group setting).

        other args:
            forwarded to ``scipy.optimize.direct``, refer to
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.direct.html
    """
    def __init__(
        self,
        params,
        lb: float,
        ub: float,
        maxfun: int | None = 1000,
        maxiter: int = 1000,
        eps: float = 0.0001,
        locally_biased: bool = True,
        f_min: float = -np.inf,
        f_min_rtol: float = 0.0001,
        vol_tol: float = 1e-16,
        len_tol: float = 0.000001,
        callback = None,
    ):
        super().__init__(params, lb=lb, ub=ub)

        # keyword arguments forwarded to scipy on every step
        self._kwargs = dict(
            maxfun=maxfun,
            maxiter=maxiter,
            eps=eps,
            locally_biased=locally_biased,
            f_min=f_min,
            f_min_rtol=f_min_rtol,
            vol_tol=vol_tol,
            len_tol=len_tol,
            callback=callback,
        )

        # error-capture state read by _objective. It is initialized here so that
        # _objective is usable before the first step, and reset at every step.
        self.raised = False
        self.e: Exception | None = None

    def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
        """Load ``x`` into ``params`` and return ``closure(False)``.

        Returns ``inf`` once the closure has raised during the current step.
        """
        if self.raised: return np.inf
        try:
            params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
            return _ensure_float(closure(False))
        except Exception as e:
            # remember the error and let the search wind down; step() re-raises it
            self.e = e
            self.raised = True
            return np.inf

    @torch.no_grad
    def step(self, closure: Closure):
        """Run a full DIRECT search, then re-raise any closure error or return the best value."""
        self.raised = False
        self.e = None

        params = self.get_params()

        # one (lb, ub) pair per flattened parameter element
        lb, ub = self.group_vals('lb', 'ub', cls=list)
        bounds = []
        for p, l, u in zip(params, lb, ub):
            bounds.extend([(l, u)] * p.numel())

        res = scipy.optimize.direct(
            partial(self._objective, params=params, closure=closure),
            bounds=bounds,
            **self._kwargs
        )

        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))

        if self.e is not None: raise self.e from None
        return res.fun
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
class ScipyBrute(Optimizer):
    """Use ``scipy.optimize.brute`` (grid search, optionally polished by ``finish``) as a pytorch optimizer.

    Each :meth:`step` evaluates the closure on a grid with ``Ns`` points per
    parameter element inside ``[lb, ub]``. It sets the parameters to the best
    point and returns the corresponding objective value.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lb (float): lower bound for every parameter element (per-group setting).
        ub (float): upper bound for every parameter element (per-group setting).
        full_output (int, optional): kept for backward compatibility. The step always
            requests full output from scipy internally, to obtain the minimum value.

        other args:
            forwarded to ``scipy.optimize.brute``, refer to
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.brute.html
    """
    def __init__(
        self,
        params,
        lb: float,
        ub: float,
        Ns: int = 20,
        full_output: int = 0,
        finish = scipy.optimize.fmin,
        disp: bool = False,
        workers: int = 1
    ):
        super().__init__(params, lb=lb, ub=ub)

        # keyword arguments forwarded to scipy on every step
        self._kwargs = dict(Ns=Ns, full_output=full_output, finish=finish, disp=disp, workers=workers)

    def _objective(self, x: np.ndarray, params: TensorList, closure):
        """Load ``x`` into ``params`` and return ``closure(False)`` as a float."""
        # depending on scipy version/dimension x may arrive 0-d; from_vec_ needs a vector
        x = np.atleast_1d(x)
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return _ensure_float(closure(False))

    @torch.no_grad
    def step(self, closure: Closure):
        """Run a full brute-force search, set the parameters to the best point and return its value."""
        params = self.get_params()

        # one (lb, ub) range per flattened parameter element
        lb, ub = self.group_vals('lb', 'ub', cls=list)
        bounds = []
        for p, l, u in zip(params, lb, ub):
            bounds.extend([(l, u)] * p.numel())

        # With full_output brute returns (xmin, fmin, grid, Jout). Without it, it returns
        # only xmin; the old code then crashed when a user set full_output=1.
        kwargs = {**self._kwargs, 'full_output': True}
        xmin, best_value, *_ = scipy.optimize.brute(
            partial(self._objective, params = params, closure = closure),
            ranges=bounds,
            **kwargs
        )

        # for a single variable without ``finish`` scipy returns a scalar, not an array
        x = np.atleast_1d(np.asarray(xmin))
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        return float(best_value)
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
from . import linear_operator
|
|
2
|
-
from .matrix_funcs import (
|
|
3
|
-
eigvals_func,
|
|
4
|
-
inv_sqrt_2x2,
|
|
5
|
-
matrix_power_eigh,
|
|
6
|
-
singular_vals_func,
|
|
7
|
-
x_inv,
|
|
8
|
-
)
|
|
9
|
-
from .orthogonalize import gram_schmidt
|
|
10
|
-
from .qr import qr_householder
|
|
11
|
-
from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
|
|
12
|
-
from .svd import randomized_svd
|
|
@@ -1,87 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
from collections.abc import Callable
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
|
-
def eigvals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """Apply ``fn`` to the eigenvalues of the Hermitian matrix ``A`` and rebuild the matrix.

    Uses ``torch.linalg.eigh`` to write ``A = Q diag(L) Q^H`` and returns ``Q diag(fn(L)) Q^H``.
    All operations broadcast over leading batch dimensions.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(A) # pylint:disable=not-callable
    mapped = fn(eigenvalues)
    # scaling column j of Q by mapped[j] equals Q @ diag(mapped) without building the diagonal
    scaled = eigenvectors * mapped[..., None, :]
    return scaled @ eigenvectors.mH
|
|
10
|
-
|
|
11
|
-
def singular_vals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """Apply ``fn`` to the singular values of ``A`` and rebuild the matrix.

    Computes the reduced SVD ``A = U diag(S) Vh`` and returns ``U diag(fn(S)) Vh``. With an
    identity ``fn`` this reproduces ``A``. Rectangular and batched inputs are supported.

    ``torch.linalg.svd`` already returns ``Vh`` (the conjugate transpose of ``V``), so it is
    multiplied directly rather than transposed again. The reduced SVD (``full_matrices=False``)
    keeps ``U`` and ``S`` shape-compatible for non-square ``A``.
    """
    U, S, Vh = torch.linalg.svd(A, full_matrices=False) # pylint:disable=not-callable
    S = fn(S)
    # scaling column k of U by S[k] equals U @ diag(S) without building the diagonal
    return (U * S.unsqueeze(-2)) @ Vh
|
|
15
|
-
|
|
16
|
-
def matrix_power_eigh(A: torch.Tensor, pow:float):
    """Raise the Hermitian matrix ``A`` to the power ``pow`` via its eigendecomposition.

    Returns ``Q diag(L ** pow) Q^H``, where ``A = Q diag(L) Q^H`` comes from ``torch.linalg.eigh``.
    Unless ``pow`` is an even number, eigenvalues are first clipped from below to
    ``2 * finfo(dtype).tiny``. Presumably this keeps fractional and negative powers finite and
    real; even powers use the raw eigenvalues.
    """
    eigvals, eigvecs = torch.linalg.eigh(A) # pylint:disable=not-callable

    even_power = pow % 2 == 0
    if not even_power:
        eigvals.clip_(min = torch.finfo(A.dtype).tiny * 2)

    # column-scale Q by L**pow, i.e. Q @ diag(L**pow)
    scaled = eigvecs * eigvals.pow(pow)[..., None, :]
    return scaled @ eigvecs.mH
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def inv_sqrt_2x2(A: torch.Tensor, force_pd: bool=False) -> torch.Tensor:
    """Closed-form inverse square root of a (possibly batched) 2x2 matrix.

    Uses ``sqrt(A) = (A + s*I) / tau``, with ``s = sqrt(det A)`` and ``tau = sqrt(trace A + 2*s)``.
    Its inverse is ``adj(A + s*I) / (s * tau)``. This avoids ``torch.linalg`` and is much faster
    for many tiny matrices. Intermediate quantities are clamped from below to
    ``2 * finfo(dtype).tiny``, so the function does not raise on singular or indefinite input.

    Args:
        A: tensor whose last two dimensions are 2x2; leading dimensions are a batch.
        force_pd: if True, both diagonal entries are first shifted by ``max(-lambda_min, 0) + tiny``.
            Here ``lambda_min`` is the smaller root of the 2x2 characteristic polynomial, so the
            shifted matrix is positive definite.

    Returns:
        Tensor with the same shape as ``A`` holding the inverse square root.
    """
    tiny = torch.finfo(A.dtype).tiny * 2

    a00, a01 = A[..., 0, 0], A[..., 0, 1]
    a10, a11 = A[..., 1, 0], A[..., 1, 1]

    if force_pd:
        # eigenvalues are trace/2 +- sqrt(trace^2/4 - det); clamping the discriminant avoids
        # NaN when the eigenvalues are a complex pair (non-symmetric A)
        tr = a00 + a11
        half_tr = tr / 2
        radius = tr.pow(2).div_(4).sub_((a00 * a11).sub_(a01 * a10)).clamp_(min=tiny).sqrt_()
        lam_min = torch.minimum(half_tr + radius, half_tr - radius)
        shift = lam_min.neg_().clamp_(min=0) + tiny
        a00 = a00 + shift
        a11 = a11 + shift

    s = (a00 * a11).sub_(a01 * a10).clamp(min=tiny).sqrt_()
    tau = (a00 + a11 + 2 * s).clamp(min=tiny).sqrt_()
    scale = (s * tau).clamp(min=tiny).reciprocal_()[..., None, None]

    # adjugate of (A + s*I)
    adj = torch.stack(
        [
            torch.stack([a11 + s, -a01], dim=-1),
            torch.stack([-a10, a00 + s], dim=-1),
        ],
        dim=-2,
    )
    return scale * adj
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
def x_inv(diag: torch.Tensor,antidiag: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Invert an "X"-shaped matrix stored as its main diagonal and anti-diagonal.

    The matrix ``M`` has ``M[i, i] = diag[i]`` and ``M[i, n-1-i] = antidiag[i]``, with zeros elsewhere.
    Rows and columns ``i`` and ``n-1-i`` form an independent 2x2 block, so the inverse is again
    X-shaped and each entry comes from the closed-form inverse of that block. For odd ``n`` the
    middle entry is shared by both vectors. In that case
    ``inv_diag[m] + inv_antidiag[m] == 1 / (diag[m] + antidiag[m])``.

    Invertibility is not checked: a zero block determinant yields inf/nan entries.

    Args:
        diag: 1-D tensor of main-diagonal entries.
        antidiag: 1-D tensor of anti-diagonal entries, same length as ``diag``.

    Returns:
        ``(inv_diag, inv_antidiag)`` describing the inverse in the same layout.

    Raises:
        ValueError: if either input is not 1-D or their lengths differ.
    """
    # validate rank before reading .shape[0], so 0-d inputs raise ValueError rather than IndexError
    if diag.dim() != 1 or antidiag.dim() != 1 or antidiag.shape[0] != diag.shape[0]:
        raise ValueError("Input tensors must be 1D and have the same size.")
    n = diag.shape[0]
    if n == 0:
        return torch.empty_like(diag), torch.empty_like(antidiag)

    # partner entries at the opposite index n-1-i
    diag_rev = torch.flip(diag, dims=[0])
    antidiag_rev = torch.flip(antidiag, dims=[0])

    # block determinants: det_i = d[i] * d[n-1-i] - a[i] * a[n-1-i]
    determinant_vec = diag * diag_rev - antidiag * antidiag_rev

    # inverse diagonal elements: y_d[i] = d[n-1-i] / det_i
    inv_diag_vec = diag_rev / determinant_vec

    # inverse anti-diagonal elements: y_a[i] = -a[i] / det_i
    inv_anti_diag_vec = -antidiag / determinant_vec

    return inv_diag_vec, inv_anti_diag_vec
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
from typing import overload
|
|
2
|
-
import torch
|
|
3
|
-
from ..tensorlist import TensorList
|
|
4
|
-
|
|
5
|
-
@overload
def gram_schmidt(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Tensor form: returns ``x`` unchanged and ``y`` made orthogonal to it."""
@overload
def gram_schmidt(x: TensorList, y: TensorList) -> tuple[TensorList, TensorList]:
    """TensorList form: returns ``x`` unchanged and ``y`` made orthogonal to it."""
|
|
9
|
-
def gram_schmidt(x, y):
    """Orthogonalize ``y`` against ``x`` with one Gram-Schmidt step; ``x`` is returned unchanged.

    Returns ``(x, y - (<x, y> / <x, x>) * x)``. Here ``<., .>`` is the inner product over all
    elements; for a TensorList it is summed over every tensor in the list. The new ``y`` therefore
    satisfies ``<x, y_new> = 0`` up to rounding. ``<x, x>`` is clipped from below to
    ``2 * finfo(dtype).tiny``, so an all-zero ``x`` leaves ``y`` unchanged instead of producing NaN.

    The previous elementwise form ``y - (x*y) / (x*x)`` reduces to ``y - y/x``, which is not
    orthogonal to ``x`` in general.
    """
    if isinstance(x, torch.Tensor):
        eps = torch.finfo(x.dtype).tiny * 2
        coeff = (x * y).sum() / (x * x).sum().clip(min=eps)
        return x, y - x * coeff

    # TensorList: inner products are global sums over all tensors in the list
    if len(x) == 0:
        return x, y
    eps = torch.finfo(x[0].dtype).tiny * 2
    # torch.stack assumes all tensors in the list share one device
    dot_xy = torch.stack([(xi * yi).sum() for xi, yi in zip(x, y)]).sum()
    dot_xx = torch.stack([(xi * xi).sum() for xi in x]).sum()
    coeff = (dot_xy / dot_xx.clip(min=eps)).item()
    return x, y - x * coeff
|