torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
```diff
@@ -69,7 +69,7 @@ def _ensure_tensor(x):
 inf = float('inf')
 Closure = Callable[[bool], Any]
 
-class NLOptOptimizer(Optimizer):
+class NLOptWrapper(Optimizer):
     """Use nlopt as pytorch optimizer, with gradient supplied by pytorch autograd.
     Note that this performs full minimization on each step,
     so usually you would want to perform a single step, although performing multiple steps will refine the
@@ -96,9 +96,9 @@ class NLOptOptimizer(Optimizer):
         self,
         params,
         algorithm: int | _ALGOS_LITERAL,
-        maxeval: int | None,
         lb: float | None = None,
         ub: float | None = None,
+        maxeval: int | None = 10000, # None can stall on some algos and because they are threaded C you can't even interrupt them
         stopval: float | None = None,
         ftol_rel: float | None = None,
         ftol_abs: float | None = None,
@@ -122,22 +122,33 @@ class NLOptOptimizer(Optimizer):
         self._last_loss = None
 
     def _f(self, x: np.ndarray, grad: np.ndarray, closure, params: TensorList):
-
-        if t is None:
+        if self.raised:
             if self.opt is not None: self.opt.force_stop()
-            return
-
-
-
-
-
+            return np.inf
+        try:
+            t = _ensure_tensor(x)
+            if t is None:
+                if self.opt is not None: self.opt.force_stop()
+                return None
+            params.from_vec_(t.to(params[0], copy=False))
+            if grad.size > 0:
+                with torch.enable_grad(): loss = closure()
+                self._last_loss = _ensure_float(loss)
+                grad[:] = params.ensure_grad_().grad.to_vec().reshape(grad.shape).detach().cpu().numpy()
+                return self._last_loss
+
+            self._last_loss = _ensure_float(closure(False))
             return self._last_loss
-
-
-
+        except Exception as e:
+            self.e = e
+            self.raised = True
+            if self.opt is not None: self.opt.force_stop()
+            return np.inf
 
     @torch.no_grad
     def step(self, closure: Closure): # pylint: disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+        self.e = None
+        self.raised = False
         params = self.get_params()
 
         # make bounds
@@ -175,6 +186,9 @@ class NLOptOptimizer(Optimizer):
         except Exception as e:
             raise e from None
 
+        if x is not None: params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        if self.e is not None: raise self.e from None
+
         if self._last_loss is None or x is None: return closure(False)
-
+
         return self._last_loss
```
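Since the wrapper's docstring notes that a single `step(closure)` runs a full nlopt minimization, typical usage is one step with a closure that optionally computes gradients. A minimal sketch, assuming an illustrative model, data, and the gradient-based `nlopt.LD_LBFGS` algorithm (the closure convention, where `closure(False)` returns the loss without gradients, follows the `Closure` type above):

```python
import nlopt
import torch
from torchzero.optim.wrappers.nlopt import NLOptWrapper  # module path taken from the file list above

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

def closure(backward: bool = True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

opt = NLOptWrapper(model.parameters(), algorithm=nlopt.LD_LBFGS, maxeval=1000)
opt.step(closure)  # one step = one full nlopt run
```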
torchzero/optim/wrappers/optuna.py
ADDED
```diff
@@ -0,0 +1,70 @@
+import typing
+from collections import abc
+
+import numpy as np
+import torch
+
+import optuna
+
+from ...utils import Optimizer
+
+def silence_optuna():
+    optuna.logging.set_verbosity(optuna.logging.WARNING)
+
+def _ensure_float(x) -> float:
+    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+    if isinstance(x, np.ndarray): return float(x.item())
+    return float(x)
+
+
+class OptunaSampler(Optimizer):
+    """Optimize your next SOTA model using hyperparameter optimization.
+
+    Note - optuna is surprisingly scalable to large number of parameters (up to 10,000), despite literally requiring a for-loop because it only supports scalars. Default TPESampler is good for BBO. Maybe not for NNs...
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lb (float): lower bounds.
+        ub (float): upper bounds.
+        sampler (optuna.samplers.BaseSampler | type[optuna.samplers.BaseSampler] | None, optional): sampler. Defaults to None.
+        silence (bool, optional): makes optuna not write a lot of very useful information to console. Defaults to True.
+    """
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        sampler: "optuna.samplers.BaseSampler | type[optuna.samplers.BaseSampler] | None" = None,
+        silence: bool = True,
+    ):
+        if silence: silence_optuna()
+        super().__init__(params, lb=lb, ub=ub)
+
+        if isinstance(sampler, type): sampler = sampler()
+        self.sampler = sampler
+        self.study = None
+
+    @torch.no_grad
+    def step(self, closure):
+
+        params = self.get_params()
+        if self.study is None:
+            self.study = optuna.create_study(sampler=self.sampler)
+
+        # some optuna samplers use torch
+        with torch.enable_grad():
+            trial = self.study.ask()
+
+            suggested = []
+            for gi,g in enumerate(self.param_groups):
+                for pi,p in enumerate(g['params']):
+                    lb, ub = g['lb'], g['ub']
+                    suggested.extend(trial.suggest_float(f'g{gi}_p{pi}_w{i}', lb, ub) for i in range(p.numel()))
+
+            vec = torch.as_tensor(suggested).to(params[0])
+            params.from_vec_(vec)
+
+        loss = closure()
+        with torch.enable_grad(): self.study.tell(trial, loss)
+
+        return loss
```
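As the docstring notes, every parameter element becomes its own `suggest_float` variable, so each `step` asks the study for one trial and reports the closure value back. A minimal usage sketch, assuming a toy quadratic objective and an explicit `TPESampler` (both illustrative, not part of the diff):

```python
import optuna
import torch
from torchzero.optim.wrappers.optuna import OptunaSampler  # module path taken from the file list above

x = torch.nn.Parameter(torch.randn(10))

def closure():
    # black-box objective: only the scalar value is reported to the study
    return float((x ** 2).sum())

opt = OptunaSampler([x], lb=-5.0, ub=5.0, sampler=optuna.samplers.TPESampler())
for _ in range(100):
    loss = opt.step(closure)  # one trial per step
```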
torchzero/optim/wrappers/scipy.py
CHANGED
```diff
@@ -11,9 +11,9 @@ from ...utils import Optimizer, TensorList
 from ...utils.derivatives import jacobian_and_hessian_mat_wrt, jacobian_wrt
 from ...modules.second_order.newton import tikhonov_
 
-def _ensure_float(x):
+def _ensure_float(x) -> float:
     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    if isinstance(x, np.ndarray): return x.item()
+    if isinstance(x, np.ndarray): return float(x.item())
     return float(x)
 
 def _ensure_numpy(x):
@@ -139,9 +139,11 @@ class ScipyMinimize(Optimizer):
 
         # make bounds
         lb, ub = self.group_vals('lb', 'ub', cls=list)
-        bounds =
-
-        bounds
+        bounds = None
+        if any(b is not None for b in lb) or any(b is not None for b in ub):
+            bounds = []
+            for p, l, u in zip(params, lb, ub):
+                bounds.extend([(l, u)] * p.numel())
 
         if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
             x0 = x0.astype(np.float64) # those methods error without this
```
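The new bounds handling turns per-group scalar `lb`/`ub` values into one `(lb, ub)` pair per flattened parameter element, which is the format scipy's `bounds` argument expects. A small standalone sketch of the same expansion (the tensors and bound values are illustrative):

```python
import torch

params = [torch.zeros(2, 3), torch.zeros(4)]   # two parameter tensors
lb, ub = [-1.0, None], [1.0, None]             # per-group bounds; second group unbounded

bounds = None
if any(b is not None for b in lb) or any(b is not None for b in ub):
    bounds = []
    for p, l, u in zip(params, lb, ub):
        bounds.extend([(l, u)] * p.numel())    # one pair per element

print(len(bounds))            # 10 = 6 elements of the 2x3 tensor + 4 of the vector
print(bounds[0], bounds[-1])  # (-1.0, 1.0) (None, None)
```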
```diff
@@ -265,7 +267,8 @@ class ScipyDE(Optimizer):
     def __init__(
         self,
         params,
-
+        lb: float,
+        ub: float,
         strategy: Literal['best1bin', 'best1exp', 'rand1bin', 'rand1exp', 'rand2bin', 'rand2exp',
                           'randtobest1bin', 'randtobest1exp', 'currenttobest1bin', 'currenttobest1exp',
                           'best2exp', 'best2bin'] = 'best1bin',
@@ -287,12 +290,11 @@ class ScipyDE(Optimizer):
         integrality = None,
 
     ):
-        super().__init__(params,
+        super().__init__(params, lb=lb, ub=ub)
 
         kwargs = locals().copy()
-        del kwargs['self'], kwargs['params'], kwargs['
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
         self._kwargs = kwargs
-        self._lb, self._ub = bounds
 
     def _objective(self, x: np.ndarray, params: TensorList, closure):
         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
@@ -303,7 +305,11 @@ class ScipyDE(Optimizer):
         params = self.get_params()
 
         x0 = params.to_vec().detach().cpu().numpy()
-
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
 
         res = scipy.optimize.differential_evolution(
             partial(self._objective, params = params, closure = closure),
@@ -321,7 +327,8 @@ class ScipyDualAnnealing(Optimizer):
     def __init__(
         self,
         params,
-
+        lb: float,
+        ub: float,
         maxiter=1000,
         minimizer_kwargs=None,
         initial_temp=5230.0,
@@ -332,23 +339,25 @@ class ScipyDualAnnealing(Optimizer):
         rng=None,
         no_local_search=False,
     ):
-        super().__init__(params,
+        super().__init__(params, lb=lb, ub=ub)
 
         kwargs = locals().copy()
-        del kwargs['self'], kwargs['params'], kwargs['
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
         self._kwargs = kwargs
-        self._lb, self._ub = bounds
 
     def _objective(self, x: np.ndarray, params: TensorList, closure):
         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
         return _ensure_float(closure(False))
 
     @torch.no_grad
-    def step(self, closure: Closure)
+    def step(self, closure: Closure):
         params = self.get_params()
 
         x0 = params.to_vec().detach().cpu().numpy()
-
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
 
         res = scipy.optimize.dual_annealing(
             partial(self._objective, params = params, closure = closure),
@@ -360,3 +369,145 @@ class ScipyDualAnnealing(Optimizer):
         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
         return res.fun
 
+
+
+class ScipySHGO(Optimizer):
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        constraints = None,
+        n: int = 100,
+        iters: int = 1,
+        callback = None,
+        minimizer_kwargs = None,
+        options = None,
+        sampling_method: str = 'simplicial',
+    ):
+        super().__init__(params, lb=lb, ub=ub)
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
+
+        res = scipy.optimize.shgo(
+            partial(self._objective, params = params, closure = closure),
+            bounds=bounds,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.fun
+
+
+class ScipyDIRECT(Optimizer):
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        maxfun: int | None = 1000,
+        maxiter: int = 1000,
+        eps: float = 0.0001,
+        locally_biased: bool = True,
+        f_min: float = -np.inf,
+        f_min_rtol: float = 0.0001,
+        vol_tol: float = 1e-16,
+        len_tol: float = 0.000001,
+        callback = None,
+    ):
+        super().__init__(params, lb=lb, ub=ub)
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
+        if self.raised: return np.inf
+        try:
+            params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+            return _ensure_float(closure(False))
+        except Exception as e:
+            # he he he ha, I found a way to make exceptions work in fcmaes and scipy direct
+            self.e = e
+            self.raised = True
+            return np.inf
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        self.raised = False
+        self.e = None
+
+        params = self.get_params()
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
+
+        res = scipy.optimize.direct(
+            partial(self._objective, params=params, closure=closure),
+            bounds=bounds,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        if self.e is not None: raise self.e from None
+        return res.fun
+
+
+
+
+class ScipyBrute(Optimizer):
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        Ns: int = 20,
+        full_output: int = 0,
+        finish = scipy.optimize.fmin,
+        disp: bool = False,
+        workers: int = 1
+    ):
+        super().__init__(params, lb=lb, ub=ub)
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
+
+        x0 = scipy.optimize.brute(
+            partial(self._objective, params = params, closure = closure),
+            ranges=bounds,
+            **self._kwargs
+        )
+        params.from_vec_(torch.from_numpy(x0).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return None
```
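All of the new global-optimization wrappers above follow the same pattern: box bounds are given per parameter group via `lb`/`ub`, `_objective` writes each candidate vector back into the parameters and evaluates `closure(False)`, and a single `step` runs the full scipy routine. A minimal usage sketch with `ScipyDIRECT`, assuming an illustrative derivative-free 2D Rosenbrock closure and evaluation budget:

```python
import torch
from torchzero.optim.wrappers.scipy import ScipyDIRECT  # module path taken from the file list above

x = torch.nn.Parameter(torch.zeros(2))

def closure(backward: bool = False):
    # derivative-free objective: these wrappers never request backward
    return float((1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2)

opt = ScipyDIRECT([x], lb=-2.0, ub=2.0, maxfun=2000)
loss = opt.step(closure)  # one step = one full DIRECT search over the box [-2, 2]^2
```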
torchzero/utils/__init__.py
CHANGED
```diff
@@ -9,11 +9,7 @@ from .optimizer import (
     get_group_vals,
     get_params,
     get_state_vals,
-
-    grad_vec_at_params,
-    loss_at_params,
-    loss_grad_at_params,
-    loss_grad_vec_at_params,
+    unpack_states,
 )
 from .params import (
     Params,
@@ -22,6 +18,6 @@ from .params import (
     _copy_param_groups,
     _make_param_groups,
 )
-from .python_tools import flatten, generic_eq, reduce_dim
-from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like
+from .python_tools import flatten, generic_eq, generic_ne, reduce_dim, unpack_dicts
+from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like, generic_finfo_eps
 from .torch_tools import tofloat, tolist, tonumpy, totensor, vec_to_tensors, vec_to_tensors_, set_storage_
```
torchzero/utils/derivatives.py
CHANGED
```diff
@@ -2,6 +2,7 @@ from collections.abc import Iterable, Sequence
 
 import torch
 import torch.autograd.forward_ad as fwAD
+from typing import Literal
 
 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
 
@@ -157,7 +158,7 @@ def hessian_mat(
     method="func",
     vectorize=False,
     outer_jacobian_strategy="reverse-mode",
-):
+) -> torch.Tensor:
     """
     returns hessian matrix for parameters (as if they were flattened and concatenated into a vector).
 
@@ -189,7 +190,7 @@ def hessian_mat(
         return loss
 
     if method == 'func':
-        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph))
+        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph)) # pyright:ignore[reportReturnType]
 
     if method == 'autograd.functional':
         return torch.autograd.functional.hessian(
@@ -198,7 +199,7 @@ def hessian_mat(
             create_graph=create_graph,
             vectorize=vectorize,
             outer_jacobian_strategy=outer_jacobian_strategy,
-        )
+        ) # pyright:ignore[reportReturnType]
     raise ValueError(method)
 
 def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
@@ -510,4 +511,4 @@ def hvp_fd_forward(
     torch._foreach_div_(hvp_, h)
 
     if normalize: torch._foreach_mul_(hvp_, vec_norm)
-    return loss, hvp_
+    return loss, hvp_
```
torchzero/utils/linalg/__init__.py
CHANGED
```diff
@@ -2,4 +2,4 @@ from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix
 from .orthogonalize import gram_schmidt
 from .qr import qr_householder
 from .svd import randomized_svd
-from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
+from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve, steihaug_toint_cg
```