torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/weight_decay/weight_decay.py CHANGED

@@ -4,7 +4,7 @@ from typing import Literal
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states, Metrics
 
 
 @torch.no_grad
@@ -14,7 +14,7 @@ def weight_decay_(
     weight_decay: float | NumberList,
     ord: int = 2
 ):
-    """returns
+    """modifies in-place and returns ``grad_``."""
     if ord == 1: return grad_.add_(params.sign().mul_(weight_decay))
     if ord == 2: return grad_.add_(params.mul(weight_decay))
     if ord - 1 % 2 != 0: return grad_.add_(params.pow(ord-1).mul_(weight_decay))
@@ -29,39 +29,38 @@ class WeightDecay(Transform):
         ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
         target (Target, optional): what to set on var. Defaults to 'update'.
 
-    Examples:
-    )
+    ### Examples:
+
+    Adam with non-decoupled weight decay
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.WeightDecay(1e-3),
+        tz.m.Adam(),
+        tz.m.LR(1e-3)
+    )
+    ```
+
+    Adam with decoupled weight decay that still scales with learning rate
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adam(),
+        tz.m.WeightDecay(1e-3),
+        tz.m.LR(1e-3)
+    )
+    ```
+
+    Adam with fully decoupled weight decay that doesn't scale with learning rate
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adam(),
+        tz.m.LR(1e-3),
+        tz.m.WeightDecay(1e-6)
+    )
+    ```
 
     """
     def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
@@ -77,7 +76,7 @@ class WeightDecay(Transform):
         return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
 
 class RelativeWeightDecay(Transform):
-    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of
+    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.
 
     Args:
         weight_decay (float): relative weight decay scale.
@@ -85,40 +84,42 @@ class RelativeWeightDecay(Transform):
         norm_input (str, optional):
             determines what should weight decay be relative to. "update", "grad" or "params".
             Defaults to "update".
+        metric (Ords, optional):
+            metric (norm, etc) that weight decay should be relative to.
+            defaults to 'mad' (mean absolute deviation).
         target (Target, optional): what to set on var. Defaults to 'update'.
 
-    Examples:
-    )
+    ### Examples:
+
+    Adam with non-decoupled relative weight decay
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.RelativeWeightDecay(1e-1),
+        tz.m.Adam(),
+        tz.m.LR(1e-3)
+    )
+    ```
+
+    Adam with decoupled relative weight decay
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adam(),
+        tz.m.RelativeWeightDecay(1e-1),
+        tz.m.LR(1e-3)
+    )
+    ```
     """
     def __init__(
         self,
         weight_decay: float = 0.1,
         ord: int = 2,
         norm_input: Literal["update", "grad", "params"] = "update",
+        metric: Metrics = 'mad',
         target: Target = "update",
     ):
-        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input)
+        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
         super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)
 
     @torch.no_grad
@@ -127,6 +128,7 @@ class RelativeWeightDecay(Transform):
 
         ord = settings[0]['ord']
         norm_input = settings[0]['norm_input']
+        metric = settings[0]['metric']
 
         if norm_input == 'update': src = TensorList(tensors)
         elif norm_input == 'grad':
@@ -137,9 +139,8 @@ class RelativeWeightDecay(Transform):
         else:
             raise ValueError(norm_input)
 
-
-
-        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * mean_abs, ord)
+        norm = src.global_metric(metric)
+        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)
 
 
 @torch.no_grad
@@ -162,7 +163,7 @@ class DirectWeightDecay(Module):
     @torch.no_grad
     def step(self, var):
         weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
-        ord = self.
+        ord = self.defaults['ord']
 
         decay_weights_(var.params, weight_decay, ord)
         return var
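For context on the new ``metric`` argument, here is a minimal sketch of ``RelativeWeightDecay`` inside a training step. Only the module chain mirrors the docstring examples above; the toy model, closure convention, and loop are illustrative assumptions, not taken from this diff.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# decoupled relative weight decay: the decay magnitude tracks the mean absolute
# value ('mad') of the update produced by Adam on every step
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.RelativeWeightDecay(1e-1, metric='mad'),
    tz.m.LR(1e-3),
)

def closure(backward=True):
    loss = model(torch.randn(16, 10)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```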
torchzero/modules/zeroth_order/__init__.py ADDED

@@ -0,0 +1 @@
+from .cd import CD
torchzero/modules/zeroth_order/cd.py ADDED

@@ -0,0 +1,122 @@
+import math
+import random
+import warnings
+from functools import partial
+from typing import Literal
+
+import numpy as np
+import torch
+
+from ...core import Module
+from ...utils import NumberList, TensorList
+
+class CD(Module):
+    """Coordinate descent. Proposes a descent direction along a single coordinate.
+    A line search such as ``tz.m.ScipyMinimizeScalar(maxiter=8)`` or a fixed step size can be used after this.
+
+    Args:
+        h (float, optional): finite difference step size. Defaults to 1e-3.
+        grad (bool, optional):
+            if True, scales direction by gradient estimate. If False, the scale is fixed to 1. Defaults to True.
+        adaptive (bool, optional):
+            whether to adapt finite difference step size, this requires an additional buffer. Defaults to True.
+        index (str, optional):
+            index selection strategy.
+            - "cyclic" - repeatedly cycles through each coordinate, e.g. ``1,2,3,1,2,3,...``.
+            - "cyclic2" - cycles forward and then backward, e.g ``1,2,3,3,2,1,1,2,3,...`` (default).
+            - "random" - picks coordinate randomly.
+        threepoint (bool, optional):
+            whether to use three points (three function evaluatins) to determine descent direction.
+            if False, uses two points, but then ``adaptive`` can't be used. Defaults to True.
+    """
+    def __init__(self, h:float=1e-3, grad:bool=True, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
+        defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None:
+            raise RuntimeError("CD requires closure")
+
+        params = TensorList(var.params)
+        ndim = params.global_numel()
+
+        grad_step_size = self.defaults['grad']
+        adaptive = self.defaults['adaptive']
+        index_strategy = self.defaults['index']
+        h = self.defaults['h']
+        threepoint = self.defaults['threepoint']
+
+        # ------------------------------ determine index ----------------------------- #
+        if index_strategy == 'cyclic':
+            idx = self.global_state.get('idx', 0) % ndim
+            self.global_state['idx'] = idx + 1
+
+        elif index_strategy == 'cyclic2':
+            idx = self.global_state.get('idx', 0)
+            self.global_state['idx'] = idx + 1
+            if idx >= ndim * 2:
+                idx = self.global_state['idx'] = 0
+            if idx >= ndim:
+                idx = (2*ndim - idx) - 1
+
+        elif index_strategy == 'random':
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = random.Random(0)
+            generator = self.global_state['generator']
+            idx = generator.randrange(0, ndim)
+
+        else:
+            raise ValueError(index_strategy)
+
+        # -------------------------- find descent direction -------------------------- #
+        h_vec = None
+        if adaptive:
+            if threepoint:
+                h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, h), cls=TensorList)
+                h = float(h_vec.flat_get(idx))
+            else:
+                warnings.warn("CD adaptive=True only works with threepoint=True")
+
+        f_0 = var.get_loss(False)
+        params.flat_set_lambda_(idx, lambda x: x + h)
+        f_p = closure(False)
+
+        # -------------------------------- threepoint -------------------------------- #
+        if threepoint:
+            params.flat_set_lambda_(idx, lambda x: x - 2*h)
+            f_n = closure(False)
+            params.flat_set_lambda_(idx, lambda x: x + h)
+
+            if adaptive:
+                assert h_vec is not None
+                if f_0 <= f_p and f_0 <= f_n:
+                    h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
+                else:
+                    if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
+                        h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))
+
+            if grad_step_size:
+                alpha = (f_p - f_n) / (2*h)
+
+            else:
+                if f_0 < f_p and f_0 < f_n: alpha = 0
+                elif f_p < f_n: alpha = -1
+                else: alpha = 1
+
+        # --------------------------------- twopoint --------------------------------- #
+        else:
+            params.flat_set_lambda_(idx, lambda x: x - h)
+            if grad_step_size:
+                alpha = (f_p - f_0) / h
+            else:
+                if f_p < f_0: alpha = -1
+                else: alpha = 1
+
+        # ----------------------------- create the update ---------------------------- #
+        update = params.zeros_like()
+        update.flat_set_(idx, alpha)
+        var.update = update
+        return var
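``CD`` only proposes a direction along one coordinate; the step length comes from whatever module follows it. A minimal usage sketch, pairing it with the scalar line search named in its own docstring; the ``tz.m.CD`` export path and the closure convention shown here are assumptions on my part, not confirmed by this diff.

```python
import torch
import torchzero as tz

# 2D Rosenbrock; the closure only needs to return the loss, since CD is zeroth-order
params = torch.nn.Parameter(torch.tensor([-1.1, 2.5]))

def closure(backward=True):
    x, y = params
    return (1 - x)**2 + 100 * (y - x**2)**2

opt = tz.Modular(
    [params],
    tz.m.CD(h=1e-3, index="cyclic2"),       # proposes a single-coordinate direction
    tz.m.ScipyMinimizeScalar(maxiter=8),    # line search suggested in the CD docstring
)

for _ in range(200):
    opt.step(closure)
```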
torchzero/optim/root.py ADDED

@@ -0,0 +1,65 @@
+"""WIP, untested"""
+from collections.abc import Callable
+
+from abc import abstractmethod
+import torch
+from ..modules.higher_order.multipoint import sixth_order_im1, sixth_order_p6, _solve
+
+def make_evaluate(f: Callable[[torch.Tensor], torch.Tensor]):
+    def evaluate(x, order) -> tuple[torch.Tensor, ...]:
+        """order=0 - returns (f,), order=1 - returns (f, J), order=2 - returns (f, J, H), etc."""
+        n = x.numel()
+
+        if order == 0:
+            f_x = f(x)
+            return (f_x, )
+
+        x.requires_grad_()
+        with torch.enable_grad():
+            f_x = f(x)
+            I = torch.eye(n, device=x.device, dtype=x.dtype),
+            g_x = torch.autograd.grad(f_x, x, I, create_graph=order!=1, is_grads_batched=True)[0]
+            ret = [f_x, g_x]
+            T = g_x
+
+            # get all derivative up to order
+            for o in range(2, order + 1):
+                is_last = o == order
+                I = torch.eye(T.numel(), device=x.device, dtype=x.dtype),
+                T = torch.autograd.grad(T.ravel(), x, I, create_graph=not is_last, is_grads_batched=True)[0]
+                ret.append(T.view(n, n, *T.shape[1:]))
+
+        return tuple(ret)
+
+    return evaluate
+
+class RootBase:
+    @abstractmethod
+    def one_iteration(
+        self,
+        x: torch.Tensor,
+        evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
+    ) -> torch.Tensor:
+        """"""
+
+
+# ---------------------------------- methods --------------------------------- #
+def newton(x:torch.Tensor, f_j, lstsq:bool=False):
+    f_x, G_x = f_j(x)
+    return x - _solve(G_x, f_x, lstsq=lstsq)
+
+class Newton(RootBase):
+    def __init__(self, lstsq: bool=False): self.lstsq = lstsq
+    def one_iteration(self, x, evaluate): return newton(x, evaluate, self.lstsq)
+
+
+class SixthOrderP6(RootBase):
+    """sixth-order iterative method
+
+    Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
+    """
+    def __init__(self, lstsq: bool=False): self.lstsq = lstsq
+    def one_iteration(self, x, evaluate):
+        def f(x): return evaluate(x, 0)[0]
+        def f_j(x): return evaluate(x, 1)
+        return sixth_order_p6(x, f, f_j, self.lstsq)
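The module ships no driver loop and is explicitly marked "WIP, untested", so the following is only an illustration of how its pieces appear intended to fit together; the loop, tolerance, and test function are my assumptions and the sketch is not guaranteed to run against the released package.

```python
import torch
# from torchzero.optim.root import make_evaluate, SixthOrderP6  # hypothetical import, module is WIP

def F(x):
    # F: R^2 -> R^2 with a root at (1, 2)
    return torch.stack([x[0] - 1.0, x[1]**2 - 4.0])

evaluate = make_evaluate(F)   # evaluate(x, 1) -> (F(x), J(x)) via batched autograd
solver = SixthOrderP6()       # wraps evaluate into the f / f_j callbacks it needs

x = torch.tensor([3.0, 5.0])
for _ in range(10):
    x = solver.one_iteration(x, evaluate).detach()
    if F(x).abs().max() < 1e-8:
        break
```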
torchzero/optim/utility/split.py CHANGED

@@ -11,12 +11,12 @@ class Split(torch.optim.Optimizer):
 
     Example:
 
+    ```python
+    opt = Split(
+        torch.optim.Adam(model.encoder.parameters(), lr=0.001),
+        torch.optim.SGD(model.decoder.parameters(), lr=0.1)
+    )
+    ```
     """
     def __init__(self, *optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer]):
         all_params = []
@@ -25,14 +25,14 @@ class Split(torch.optim.Optimizer):
         # gather all params in case user tries to access them from this object
         for i,opt in enumerate(self.optimizers):
             for p in get_params(opt.param_groups, 'all', list):
-                if p not in all_params: all_params.append(p)
+                if id(p) not in [id(pr) for pr in all_params]: all_params.append(p)
                 else: warnings.warn(
                     f'optimizers[{i}] {opt.__class__.__name__} has some duplicate parameters '
                     'that are also in previous optimizers. They will be updated multiple times.')
 
         super().__init__(all_params, {})
 
-    def step(self, closure: Callable | None = None):
+    def step(self, closure: Callable | None = None): # pyright:ignore[reportIncompatibleMethodOverride]
        loss = None
 
         # if closure provided, populate grad, otherwise each optimizer will call closure separately
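The switch to id()-based deduplication matters because ``in`` on a list of tensors falls back to tensor ``__eq__``, which is element-wise. A small standalone illustration of the failure mode the new check avoids (not taken from the diff):

```python
import torch

a = torch.zeros(3)
b = torch.zeros(3)
collected = [a]

# `b in collected` compares b == a element-wise and then calls bool() on the
# result, which raises "Boolean value of Tensor with more than one element
# is ambiguous" for any non-identical multi-element tensor.
try:
    _ = b in collected
except RuntimeError as e:
    print("membership test failed:", e)

# identity-based membership, as in the new Split.__init__, never touches __eq__
assert id(b) not in [id(p) for p in collected]
assert id(a) in [id(p) for p in collected]
```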
torchzero/optim/wrappers/fcmaes.py CHANGED

@@ -2,11 +2,12 @@ from collections.abc import Callable
 from functools import partial
 from typing import Any, Literal
 
+import numpy as np
+import torch
+
 import fcmaes
 import fcmaes.optimizer
 import fcmaes.retry
-import numpy as np
-import torch
 
 from ...utils import Optimizer, TensorList
 
torchzero/optim/wrappers/nlopt.py CHANGED

@@ -75,8 +75,6 @@ class NLOptWrapper(Optimizer):
     so usually you would want to perform a single step, although performing multiple steps will refine the
     solution.
 
-    Some algorithms are buggy with numpy>=2.
-
     Args:
         params: iterable of parameters to optimize or dicts defining parameter groups.
         algorithm (int | _ALGOS_LITERAL): optimization algorithm from https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/
torchzero/optim/wrappers/optuna.py CHANGED

@@ -6,7 +6,7 @@ import torch
 
 import optuna
 
-from ...utils import Optimizer
+from ...utils import Optimizer, totensor, tofloat
 
 def silence_optuna():
     optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -65,6 +65,6 @@ class OptunaSampler(Optimizer):
         params.from_vec_(vec)
 
         loss = closure()
-        with torch.enable_grad(): self.study.tell(trial, loss)
+        with torch.enable_grad(): self.study.tell(trial, tofloat(torch.nan_to_num(totensor(loss), 1e32)))
 
         return loss
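The new ``tell()`` call sanitizes the loss before reporting it to Optuna, so a NaN loss becomes a large finite penalty instead of breaking the study. A rough standalone equivalent using plain torch; ``totensor`` and ``tofloat`` are torchzero.utils helpers whose exact behavior is approximated here.

```python
import math
import torch

def report_value(loss):
    t = torch.as_tensor(loss, dtype=torch.float64)  # roughly totensor(loss)
    t = torch.nan_to_num(t, nan=1e32)               # NaN -> large finite penalty
    return float(t)                                 # roughly tofloat(t)

assert report_value(float('nan')) == 1e32
assert math.isclose(report_value(torch.tensor(0.5)), 0.5)
```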
torchzero/optim/wrappers/scipy.py CHANGED

@@ -4,12 +4,17 @@ from functools import partial
 from typing import Any, Literal
 
 import numpy as np
-import scipy.optimize
 import torch
 
+import scipy.optimize
+
 from ...utils import Optimizer, TensorList
-from ...utils.derivatives import
-
+from ...utils.derivatives import (
+    flatten_jacobian,
+    jacobian_and_hessian_mat_wrt,
+    jacobian_wrt,
+)
+
 
 def _ensure_float(x) -> float:
     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
@@ -21,14 +26,6 @@ def _ensure_numpy(x):
     if isinstance(x, np.ndarray): return x
     return np.array(x)
 
-def matrix_clamp(H: torch.Tensor, reg: float):
-    try:
-        eigvals, eigvecs = torch.linalg.eigh(H) # pylint:disable=not-callable
-        eigvals.clamp_(min=reg)
-        return eigvecs @ torch.diag(eigvals) @ eigvecs.mH
-    except Exception:
-        return H
-
 Closure = Callable[[bool], Any]
 
 class ScipyMinimize(Optimizer):
@@ -76,8 +73,6 @@ class ScipyMinimize(Optimizer):
         options = None,
         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
-        tikhonov: float | None = 0,
-        min_eigval: float | None = None,
     ):
         defaults = dict(lb=lb, ub=ub)
         super().__init__(params, defaults)
@@ -85,12 +80,10 @@ class ScipyMinimize(Optimizer):
         self.constraints = constraints
         self.tol = tol
         self.callback = callback
-        self.min_eigval = min_eigval
         self.options = options
 
         self.jac = jac
         self.hess = hess
-        self.tikhonov: float | None = tikhonov
 
         self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
             'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
@@ -111,9 +104,7 @@ class ScipyMinimize(Optimizer):
         with torch.enable_grad():
             value = closure(False)
             _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
-
-        if self.min_eigval is not None: H = matrix_clamp(H, self.min_eigval)
-        return H.detach().cpu().numpy()
+        return H.numpy(force=True)
 
     def _objective(self, x: np.ndarray, params: TensorList, closure):
         # set params to x
@@ -122,7 +113,10 @@ class ScipyMinimize(Optimizer):
         # return value and maybe gradients
         if self.use_jac_autograd:
             with torch.enable_grad(): value = _ensure_float(closure())
+            grad = params.ensure_grad_().grad.to_vec().numpy(force=True)
+            # slsqp requires float64
+            if self.method.lower() == 'slsqp': grad = grad.astype(np.float64)
+            return value, grad
         return _ensure_float(closure(False))
 
     @torch.no_grad
@@ -135,7 +129,7 @@ class ScipyMinimize(Optimizer):
             else: hess = None
         else: hess = self.hess
 
-        x0 = params.to_vec().
+        x0 = params.to_vec().numpy(force=True)
 
         # make bounds
         lb, ub = self.group_vals('lb', 'ub', cls=list)
@@ -167,7 +161,7 @@ class ScipyMinimize(Optimizer):
 
 
 class ScipyRootOptimization(Optimizer):
-    """Optimization via using scipy.root on gradients, mainly for experimenting!
+    """Optimization via using scipy.optimize.root on gradients, mainly for experimenting!
 
     Args:
         params: iterable of parameters to optimize or dicts defining parameter groups.
@@ -248,6 +242,72 @@ class ScipyRootOptimization(Optimizer):
         return res.fun
 
 
+class ScipyLeastSquaresOptimization(Optimizer):
+    """Optimization via using scipy.optimize.least_squares on gradients, mainly for experimenting!
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        method (str | None, optional): _description_. Defaults to None.
+        tol (float | None, optional): _description_. Defaults to None.
+        callback (_type_, optional): _description_. Defaults to None.
+        options (_type_, optional): _description_. Defaults to None.
+        jac (T.Literal['2, optional): _description_. Defaults to 'autograd'.
+    """
+    def __init__(
+        self,
+        params,
+        method='trf',
+        jac='autograd',
+        bounds=(-np.inf, np.inf),
+        ftol=1e-8, xtol=1e-8, gtol=1e-8, x_scale=1.0, loss='linear',
+        f_scale=1.0, diff_step=None, tr_solver=None, tr_options=None,
+        jac_sparsity=None, max_nfev=None, verbose=0
+    ):
+        super().__init__(params, {})
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__'], kwargs['jac']
+        self._kwargs = kwargs
+
+        self.jac = jac
+
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        # set params to x
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        # return the gradients
+        with torch.enable_grad(): self.value = closure()
+        jac = params.ensure_grad_().grad.to_vec()
+        return jac.numpy(force=True)
+
+    def _hess(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        with torch.enable_grad():
+            value = closure(False)
+            _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
+        return H.numpy(force=True)
+
+    @torch.no_grad
+    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        if self.jac == 'autograd': jac = partial(self._hess, params = params, closure = closure)
+        else: jac = self.jac
+
+        res = scipy.optimize.least_squares(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            jac=jac, # type:ignore
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.fun
+
+
+
 
 class ScipyDE(Optimizer):
     """Use scipy.minimize.differential_evolution as pytorch optimizer. Note that this performs full minimization on each step,
@@ -510,4 +570,3 @@ class ScipyBrute(Optimizer):
             **self._kwargs
         )
         params.from_vec_(torch.from_numpy(x0).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-        return None
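For the new ``ScipyLeastSquaresOptimization`` wrapper, a minimal usage sketch: a single ``step()`` runs a full ``scipy.optimize.least_squares`` solve on the gradient. The closure convention shown (backward flag, ``zero_grad`` plus ``backward``) is the usual torch pattern and is assumed here rather than taken from the diff.

```python
import torch
from torchzero.optim.wrappers.scipy import ScipyLeastSquaresOptimization

x = torch.nn.Parameter(torch.tensor([-1.5, 2.0]))
opt = ScipyLeastSquaresOptimization([x], method='trf', max_nfev=200)

def closure(backward=True):
    loss = (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2   # Rosenbrock
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)   # solves grad(loss)(x) = 0 in the least-squares sense
print(x.detach())
```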