torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/sg2.py (new file)

@@ -0,0 +1,292 @@
+import torch
+
+from ...core import Module, Chainable, step
+from ...utils import TensorList, vec_to_tensors
+from ..second_order.newton import _newton_step, _get_H
+
+def sg2_(
+    delta_g: torch.Tensor,
+    cd: torch.Tensor,
+) -> torch.Tensor:
+    """cd is c * perturbation, and must be multiplied by two if hessian estimate is two-sided
+    (or divide delta_g by two)."""
+
+    M = torch.outer(1.0 / cd, delta_g)
+    H_hat = 0.5 * (M + M.T)
+
+    return H_hat
+
+
+
+class SG2(Module):
+    """second-order stochastic gradient
+
+    SG2 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SG2(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    SG2 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SG2()),
+    )
+    ```
+
+    """
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        h: float = 1e-2,
+        beta: float | None = None,
+        damping: float = 0,
+        eigval_fn=None,
+        one_sided: bool = False, # one-sided hessian
+        use_lstsq: bool = True,
+        seed=None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, one_sided=one_sided, seed=seed, use_lstsq=use_lstsq)
+        super().__init__(defaults)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, objective):
+        k = self.global_state.get('step', 0) + 1
+        self.global_state["step"] = k
+
+        params = TensorList(objective.params)
+        closure = objective.closure
+        if closure is None:
+            raise RuntimeError("closure is required for SG2")
+        generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+        h = self.get_settings(params, "h")
+        x_0 = params.clone()
+        n_samples = self.defaults["n_samples"]
+        H_hat = None
+
+        for i in range(n_samples):
+            # generate perturbation
+            cd = params.rademacher_like(generator=generator).mul_(h)
+
+            # one sided
+            if self.defaults["one_sided"]:
+                g_0 = TensorList(objective.get_grads())
+                params.add_(cd)
+                closure()
+
+                g_p = params.grad.fill_none_(params)
+                delta_g = (g_p - g_0) * 2
+
+            # two sided
+            else:
+                params.add_(cd)
+                closure()
+                g_p = params.grad.fill_none_(params)
+
+                params.copy_(x_0)
+                params.sub_(cd)
+                closure()
+                g_n = params.grad.fill_none_(params)
+
+                delta_g = g_p - g_n
+
+            # restore params
+            params.set_(x_0)
+
+            # compute H hat
+            H_i = sg2_(
+                delta_g = delta_g.to_vec(),
+                cd = cd.to_vec(),
+            )
+
+            if H_hat is None: H_hat = H_i
+            else: H_hat += H_i
+
+        assert H_hat is not None
+        if n_samples > 1: H_hat /= n_samples
+
+        # update H
+        H = self.global_state.get("H", None)
+        if H is None: H = H_hat
+        else:
+            beta = self.defaults["beta"]
+            if beta is None: beta = k / (k+1)
+            H.lerp_(H_hat, 1-beta)
+
+        self.global_state["H"] = H
+
+
+    @torch.no_grad
+    def apply(self, objective):
+        dir = _newton_step(
+            objective=objective,
+            H = self.global_state["H"],
+            damping = self.defaults["damping"],
+            inner = self.children.get("inner", None),
+            H_tfm=None,
+            eigval_fn=self.defaults["eigval_fn"],
+            use_lstsq=self.defaults["use_lstsq"],
+            g_proj=None,
+        )
+
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
+
+    def get_H(self,objective=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
+
+
+
+
+# two sided
+# we have g via x + d, x - d
+# H via g(x + d), g(x - d)
+# 1 is x, x+2d
+# 2 is x, x-2d
+# 5 evals in total
+
+# one sided
+# g via x, x + d
+# 1 is x, x + d
+# 2 is x, x - d
+# 3 evals and can use two sided for g_0
+
+class SPSA2(Module):
+    """second-order SPSA
+
+    SPSA2 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SPSA2(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    SPSA2 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SPSA2()),
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        h: float = 1e-2,
+        beta: float | None = None,
+        damping: float = 0,
+        eigval_fn=None,
+        use_lstsq: bool = True,
+        seed=None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
+        super().__init__(defaults)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, objective):
+        k = self.global_state.get('step', 0) + 1
+        self.global_state["step"] = k
+
+        params = TensorList(objective.params)
+        closure = objective.closure
+        if closure is None:
+            raise RuntimeError("closure is required for SPSA2")
+
+        generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+        h = self.get_settings(params, "h")
+        x_0 = params.clone()
+        n_samples = self.defaults["n_samples"]
+        H_hat = None
+        g_0 = None
+
+        for i in range(n_samples):
+            # perturbations for g and H
+            cd_g = params.rademacher_like(generator=generator).mul_(h)
+            cd_H = params.rademacher_like(generator=generator).mul_(h)
+
+            # evaluate 4 points
+            x_p = x_0 + cd_g
+            x_n = x_0 - cd_g
+
+            params.set_(x_p)
+            f_p = closure(False)
+            params.add_(cd_H)
+            f_pp = closure(False)
+
+            params.set_(x_n)
+            f_n = closure(False)
+            params.add_(cd_H)
+            f_np = closure(False)
+
+            g_p_vec = (f_pp - f_p) / cd_H
+            g_n_vec = (f_np - f_n) / cd_H
+            delta_g = g_p_vec - g_n_vec
+
+            # restore params
+            params.set_(x_0)
+
+            # compute grad
+            g_i = (f_p - f_n) / (2 * cd_g)
+            if g_0 is None: g_0 = g_i
+            else: g_0 += g_i
+
+            # compute H hat
+            H_i = sg2_(
+                delta_g = delta_g.to_vec().div_(2.0),
+                cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
+            )
+            if H_hat is None: H_hat = H_i
+            else: H_hat += H_i
+
+        assert g_0 is not None and H_hat is not None
+        if n_samples > 1:
+            g_0 /= n_samples
+            H_hat /= n_samples
+
+        # set grad to approximated grad
+        objective.grads = g_0
+
+        # update H
+        H = self.global_state.get("H", None)
+        if H is None: H = H_hat
+        else:
+            beta = self.defaults["beta"]
+            if beta is None: beta = k / (k+1)
+            H.lerp_(H_hat, 1-beta)
+
+        self.global_state["H"] = H
+
+    @torch.no_grad
+    def apply(self, objective):
+        dir = _newton_step(
+            objective=objective,
+            H = self.global_state["H"],
+            damping = self.defaults["damping"],
+            inner = self.children.get("inner", None),
+            H_tfm=None,
+            eigval_fn=self.defaults["eigval_fn"],
+            use_lstsq=self.defaults["use_lstsq"],
+            g_proj=None,
+        )
+
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
+
+    def get_H(self,objective=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
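For reference, writing cΔ for the scaled Rademacher perturbation (`cd`) and δg for the gradient difference (`delta_g`), the estimate formed by `sg2_` above is

$$M_{ij} = \frac{\delta g_j}{c\,\Delta_i}, \qquad \hat H = \tfrac{1}{2}\bigl(M + M^\top\bigr),$$

and with `beta` left at its default of $k/(k+1)$, the call `H.lerp_(H_hat, 1 - beta)` keeps a running average of the per-step estimates, $H_k = \tfrac{k}{k+1} H_{k-1} + \tfrac{1}{k+1}\hat H_k$.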
torchzero/modules/restarts/restars.py

@@ -4,12 +4,14 @@ from typing import final, Literal, cast
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Objective
 from ...utils import TensorList
 from ..termination import TerminationCriteriaBase
 
 def _reset_except_self(optimizer, var, self: Module):
-    for m in optimizer.unrolled_modules:
+    for m in optimizer.unrolled_modules:
+        if m is not self:
+            m.reset()
 
 class RestartStrategyBase(Module, ABC):
     """Base class for restart strategies.
@@ -24,7 +26,7 @@ class RestartStrategyBase(Module, ABC):
         self.set_child('modules', modules)
 
     @abstractmethod
-    def should_reset(self, var:
+    def should_reset(self, var: Objective) -> bool:
         """returns whether reset should occur"""
 
     def _reset_on_condition(self, var):
@@ -39,23 +41,23 @@ class RestartStrategyBase(Module, ABC):
         return modules
 
     @final
-    def update(self,
-        modules = self._reset_on_condition(
+    def update(self, objective):
+        modules = self._reset_on_condition(objective)
         if modules is not None:
-            modules.update(
+            modules.update(objective)
 
     @final
-    def apply(self,
+    def apply(self, objective):
         # don't check here because it was check in `update`
         modules = self.children.get('modules', None)
-        if modules is None: return
-        return modules.apply(
+        if modules is None: return objective
+        return modules.apply(objective.clone(clone_updates=False))
 
     @final
-    def step(self,
-        modules = self._reset_on_condition(
-        if modules is None: return
-        return modules.step(
+    def step(self, objective):
+        modules = self._reset_on_condition(objective)
+        if modules is None: return objective
+        return modules.step(objective.clone(clone_updates=False))
 
 
 
@@ -170,7 +172,7 @@ class PowellRestart(RestartStrategyBase):
         super().__init__(defaults, modules)
 
     def should_reset(self, var):
-        g = TensorList(var.
+        g = TensorList(var.get_grads())
         cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']
 
         # -------------------------------- initialize -------------------------------- #
@@ -192,7 +194,7 @@ class PowellRestart(RestartStrategyBase):
 
         # ------------------------------- 2nd condition ------------------------------ #
         if (cond2 is not None) and (not reset):
-            d_g = TensorList(var.
+            d_g = TensorList(var.get_updates()).dot(g)
            if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                reset = True
 
@@ -229,17 +231,17 @@ class BirginMartinezRestart(Module):
 
         self.set_child("module", module)
 
-    def update(self,
+    def update(self, objective):
         module = self.children['module']
-        module.update(
+        module.update(objective)
 
-    def apply(self,
+    def apply(self, objective):
         module = self.children['module']
-
+        objective = module.apply(objective.clone(clone_updates=False))
 
         cond = self.defaults['cond']
-        g = TensorList(
-        d = TensorList(
+        g = TensorList(objective.get_grads())
+        d = TensorList(objective.get_updates())
         d_g = d.dot(g)
         d_norm = d.global_vector_norm()
         g_norm = g.global_vector_norm()
@@ -247,7 +249,7 @@ class BirginMartinezRestart(Module):
         # d in our case is same direction as g so it has a minus sign
         if -d_g > -cond * d_norm * g_norm:
             module.reset()
-
-            return
+            objective.updates = g.clone()
+            return objective
 
-        return
+        return objective
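In `BirginMartinezRestart.apply` above, `d` points in the same direction as `g` (hence the minus signs in the code), so the test `-d_g > -cond * d_norm * g_norm` is equivalent to

$$d^\top g < \text{cond} \cdot \lVert d\rVert \, \lVert g\rVert,$$

i.e. the wrapped module is reset, and the update falls back to `g.clone()`, once the cosine of the angle between `d` and `g` drops below `cond`.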
torchzero/modules/second_order/__init__.py

@@ -1,4 +1,7 @@
-from .
+from .ifn import InverseFreeNewton
+from .inm import ImprovedNewton
+from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
+from .newton import Newton
 from .newton_cg import NewtonCG, NewtonCGSteihaug
-from .nystrom import
-from .
+from .nystrom import NystromPCG, NystromSketchAndSolve
+from .rsn import SubspaceNewton
torchzero/modules/second_order/ifn.py (new file)

@@ -0,0 +1,58 @@
+import torch
+
+from ...core import Chainable, Transform, HessianMethod
+from ...utils import TensorList, vec_to_tensors
+from ...linalg.linear_operator import DenseWithInverse
+
+
+class InverseFreeNewton(Transform):
+    """Inverse-free newton's method
+
+    Reference
+    [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
+    """
+    def __init__(
+        self,
+        update_freq: int = 1,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(hessian_method=hessian_method, h=h)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
+
+    @torch.no_grad
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+
+        _, _, H = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
+
+        self.global_state["H"] = H
+
+        # inverse free part
+        if 'Y' not in self.global_state:
+            num = H.T
+            denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+
+            finfo = torch.finfo(H.dtype)
+            self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
+
+        else:
+            Y = self.global_state['Y']
+            I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+            I2 -= H @ Y
+            self.global_state['Y'] = Y @ I2
+
+
+    def apply_states(self, objective, states, settings):
+        Y = self.global_state["Y"]
+        g = torch.cat([t.ravel() for t in objective.get_updates()])
+        objective.updates = vec_to_tensors(Y@g, objective.params)
+        return objective
+
+    def get_H(self,objective=...):
+        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
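`InverseFreeNewton.update_states` above never inverts `H`: `Y` is seeded with $H^\top / (\lVert H\rVert_1 \lVert H\rVert_\infty)$ and then refined once per step with $Y \leftarrow Y(2I - HY)$, a Newton–Schulz-style iteration that converges to $H^{-1}$ for nonsingular $H$. A minimal standalone sketch of that iteration (illustrative only, not the torchzero API):

```python
import torch

# Sketch of the iteration used in InverseFreeNewton.update_states:
# Y_{k+1} = Y_k (2I - H Y_k), seeded with Y_0 = H^T / (||H||_1 * ||H||_inf).
torch.manual_seed(0)
A = torch.randn(5, 5)
H = A @ A.T + 5 * torch.eye(5)   # well-conditioned SPD test matrix

Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float("inf")))
I = torch.eye(5)
for _ in range(20):              # converges quadratically once ||I - HY|| < 1
    Y = Y @ (2 * I - H @ Y)

print(torch.allclose(Y @ H, I, atol=1e-4))  # True
```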
torchzero/modules/second_order/inm.py (new file)

@@ -0,0 +1,101 @@
+from collections.abc import Callable
+
+import torch
+
+from ...core import Chainable, Transform, HessianMethod
+from ...utils import TensorList, vec_to_tensors, unpack_states
+from ..functional import safe_clip
+from .newton import _get_H, _newton_step
+
+@torch.no_grad
+def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
+
+    yy = safe_clip(y.dot(y))
+    ss = safe_clip(s.dot(s))
+
+    term1 = y.dot(y - J@s) / yy
+    FbT = f.outer(s).mul_(term1 / ss)
+
+    P = FbT.add_(J)
+    return P
+
+def _eigval_fn(J: torch.Tensor, fn) -> torch.Tensor:
+    if fn is None: return J
+    L, Q = torch.linalg.eigh(J) # pylint:disable=not-callable
+    return (Q * L.unsqueeze(-2)) @ Q.mH
+
+class ImprovedNewton(Transform):
+    """Improved Newton's Method (INM).
+
+    Reference:
+    [Saheya, B., et al. "A new Newton-like method for solving nonlinear equations." SpringerPlus 5.1 (2016): 1269.](https://d-nb.info/1112813721/34)
+    """
+
+    def __init__(
+        self,
+        damping: float = 0,
+        use_lstsq: bool = False,
+        update_freq: int = 1,
+        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner, )
+
+    @torch.no_grad
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+
+        _, f_list, J = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
+        if f_list is None: f_list = objective.get_grads()
+
+        f = torch.cat([t.ravel() for t in f_list])
+        J = _eigval_fn(J, fs["eigval_fn"])
+
+        x_list = TensorList(objective.params)
+        f_list = TensorList(objective.get_grads())
+        x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)
+
+        # initialize on 1st step, do Newton step
+        if "P" not in self.global_state:
+            x_prev.copy_(x_list)
+            f_prev.copy_(f_list)
+            self.global_state["P"] = J
+            return
+
+        # INM update
+        s_list = x_list - x_prev
+        y_list = f_list - f_prev
+        x_prev.copy_(x_list)
+        f_prev.copy_(f_list)
+
+        self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
+
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+
+        update = _newton_step(
+            objective = objective,
+            H = self.global_state["P"],
+            damping = fs["damping"],
+            H_tfm = fs["H_tfm"],
+            eigval_fn = None, # it is applied in `update`
+            use_lstsq = fs["use_lstsq"],
+        )
+
+        objective.updates = vec_to_tensors(update, objective.params)
+
+        return objective
+
+    def get_H(self,objective=...):
+        return _get_H(self.global_state["P"], eigval_fn=None)
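For reference, with $s = x_k - x_{k-1}$, $y = f_k - f_{k-1}$ and $f$ the current gradient (residual) vector, the matrix built by `inm` above is the rank-one Jacobian correction

$$P_k = J_k + \frac{y^\top (y - J_k s)}{y^\top y} \, \frac{f\, s^\top}{s^\top s},$$

which `apply_states` then hands to `_newton_step` in place of the plain Hessian/Jacobian.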
|