torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/matrix_nag.py
ADDED

@@ -0,0 +1,122 @@
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+from torchzero.core import Chainable, Transform, HVPMethod
+from torchzero.utils import NumberList, TensorList
+
+
+def matrix_nag_(
+    tensors_: TensorList,
+    s: TensorList,
+    Hvp_fn: Callable,
+    mu: float | NumberList,
+):
+    s += tensors_
+    Hv = TensorList(Hvp_fn(s))
+    s -= Hv.mul_(mu)
+    return tensors_.add_(s)
+
+
+class MatrixNAG(Transform):
+    """nesterov momentum version of matrix momentum. It seemed to work really well but adapting doesn't work,
+    I need to test more"""
+    def __init__(
+        self,
+        mu=0.1,
+        hvp_method: HVPMethod = "autograd",
+        h: float = 1e-3,
+        adaptive:bool = False,
+        adapt_freq: int | None = None,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
+        super().__init__(defaults)
+
+        if hvp_tfm is not None:
+            self.set_child('hvp_tfm', hvp_tfm)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('p_prev')
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        assert objective.closure is not None
+        step = self.global_state.get("step", 0)
+        self.global_state["step"] = step + 1
+
+        p = TensorList(objective.params)
+        g = TensorList(objective.get_grads(create_graph=self.defaults["hvp_method"] == "autograd"))
+        p_prev = self.get_state(p, "p_prev", init=p, cls=TensorList)
+        s = p - p_prev
+        p_prev.copy_(p)
+
+        # -------------------------------- adaptive mu ------------------------------- #
+        if self.defaults["adaptive"]:
+
+            if step == 1:
+                self.global_state["mu_mul"] = 0
+
+            else:
+                # ---------------------------- deterministic case ---------------------------- #
+                if self.defaults["adapt_freq"] is None:
+                    g_prev = self.get_state(objective.params, "g_prev", cls=TensorList)
+                    y = g - g_prev
+                    g_prev.copy_(g)
+
+                    denom = y.global_vector_norm()
+                    denom = denom.clip(min = torch.finfo(denom.dtype).tiny * 2)
+                    self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+                # -------------------------------- stochastic -------------------------------- #
+                else:
+                    adapt_freq = self.defaults["adapt_freq"]
+
+                    # we start on 1st step, and want to adapt when we start, so use (step - 1)
+                    if (step - 1) % adapt_freq == 0:
+                        assert objective.closure is not None
+                        p_cur = p.clone()
+
+                        # move to previous params and evaluate p_prev with current mini-batch
+                        p.copy_(self.get_state(objective.params, 'p_prev'))
+                        with torch.enable_grad():
+                            objective.closure()
+                        g_prev = [t.grad if t.grad is not None else torch.zeros_like(t) for t in p]
+                        y = g - g_prev
+
+                        # move back to current params
+                        p.copy_(p_cur)
+
+                        denom = y.global_vector_norm()
+                        denom = denom.clip(min = torch.finfo(denom.dtype).tiny * 2)
+                        self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+        # -------------------------- matrix momentum update -------------------------- #
+        mu = self.get_settings(p, "mu", cls=NumberList)
+        if "mu_mul" in self.global_state:
+            mu = mu * self.global_state["mu_mul"]
+
+        # def Hvp_fn(v):
+        #     Hv, _ = self.Hvp(
+        #         v=v,
+        #         at_x0=True,
+        #         var=objective,
+        #         rgrad=g,
+        #         hvp_method=self.defaults["hvp_method"],
+        #         h=self.defaults["h"],
+        #         normalize=True,
+        #         retain_grad=False,
+        #     )
+        #     return Hv
+
+        _, Hvp_fn = objective.list_Hvp_function(hvp_method=self.defaults["hvp_method"], h=self.defaults["h"], at_x0=True)
+
+        objective.updates = matrix_nag_(
+            tensors_=TensorList(objective.get_updates()),
+            s=s,
+            Hvp_fn=Hvp_fn,
+            mu=mu,
+        )
+
+        return objective
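Note on the new MatrixNAG module above: stripped of the TensorList plumbing, matrix_nag_ advances the previous displacement s by the incoming update, damps it with a Hessian-vector product scaled by mu, and folds the result back into the update. Below is a self-contained toy version of that recurrence on a quadratic, with an explicit Hessian standing in for the library's Hvp machinery; the quadratic, step size, and mu are illustrative choices, not values taken from the module.

```python
import torch

# f(x) = 0.5 x^T A x, so the gradient is A x and the Hessian is A.
A = torch.tensor([[3.0, 0.5], [0.5, 1.0]])
x = torch.tensor([2.0, -1.5])
x_prev = x.clone()
lr, mu = 0.15, 0.3          # illustrative values

for _ in range(100):
    u = lr * (A @ x)        # incoming update ("tensors_")
    s = x - x_prev          # last displacement, as in apply_states
    x_prev = x.clone()
    s = s + u               # s += tensors_
    s = s - mu * (A @ s)    # s -= Hvp_fn(s) * mu  (exact Hvp here)
    x = x - (u + s)         # tensors_.add_(s), applied as a descent step

print(x)                    # close to the minimizer at the origin
```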
torchzero/modules/experimental/newton_solver.py
CHANGED

@@ -3,7 +3,7 @@ from typing import Any

import torch

-from ...core import Chainable,
+from ...core import Chainable, Optimizer, Module, step, HVPMethod
from ...utils import TensorList
from ..quasi_newton import LBFGS

@@ -12,7 +12,7 @@ class NewtonSolver(Module):
    """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)."""
    def __init__(
        self,
-        solver: Callable[[list[torch.Tensor]], Any] = lambda p:
+        solver: Callable[[list[torch.Tensor]], Any] = lambda p: Optimizer(p, LBFGS()),
        maxiter=None,
        maxiter1=None,
        tol:float | None=1e-3,
torchzero/modules/experimental/newtonnewton.py
CHANGED

@@ -7,22 +7,21 @@ from typing import Literal

import torch

-from ...core import Chainable,
+from ...core import Chainable, Transform, step
from ...linalg.linear_operator import Dense
-from ...utils import TensorList,
+from ...utils import TensorList, vec_to_tensors_
from ...utils.derivatives import (
    flatten_jacobian,
    jacobian_wrt,
)
from ..second_order.newton import (
-
-    _eigh_solve,
+    _try_cholesky_solve,
    _least_squares_solve,
-
+    _try_lu_solve,
)


-class NewtonNewton(
+class NewtonNewton(Transform):
    """Applies Newton-like preconditioning to Newton step.

    This is a method that I thought of and then it worked. Here is how it works:
@@ -34,39 +33,32 @@ class NewtonNewton(Module):
    3. Solve H2 x2 = x for x2.

    4. Optionally, repeat (if order is higher than 3.)
-
-    Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
-
-    3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
-    this is if pytorch can vectorize hessian computation efficiently.
    """
    def __init__(
        self,
        reg: float = 1e-6,
        order: int = 3,
-        search_negative: bool = False,
        vectorize: bool = True,
-
+        update_freq: int = 1,
+        inner: Chainable | None = None,
    ):
-        defaults = dict(order=order, reg=reg, vectorize=vectorize
-        super().__init__(defaults)
+        defaults = dict(order=order, reg=reg, vectorize=vectorize)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
-    def
+    def update_states(self, objective, states, settings):
+        fs = settings[0]

        params = TensorList(objective.params)
        closure = objective.closure
        if closure is None: raise RuntimeError('NewtonNewton requires closure')

-
-
-
-        order = settings['order']
-        search_negative = settings['search_negative']
-        eigval_fn = settings['eigval_fn']
+        reg = fs['reg']
+        vectorize = fs['vectorize']
+        order = fs['order']

        # ------------------------ calculate grad and hessian ------------------------ #
-
+        P = None
        with torch.enable_grad():
            loss = objective.loss = objective.loss_approx = closure(False)
            g_list = torch.autograd.grad(loss, params, create_graph=True)
@@ -81,28 +73,30 @@ class NewtonNewton(Module):
            with torch.no_grad() if is_last else nullcontext():
                H = flatten_jacobian(H_list)
                if reg != 0: H = H + I * reg
-
+                if P is None: P = H
+                else: P = P @ H

-
-
-                x =
-
-
-                if x is None: x = _least_squares_solve(H, xp)
-                xp = x.squeeze()
+                if not is_last:
+                    x = _try_cholesky_solve(H, xp)
+                    if x is None: x = _try_lu_solve(H, xp)
+                    if x is None: x = _least_squares_solve(H, xp)
+                    xp = x.squeeze()

-        self.global_state["
-        self.global_state['xp'] = xp.nan_to_num_(0,0,0)
+        self.global_state["P"] = P

    @torch.no_grad
-    def
-
-
-
+    def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
+        P = self.global_state['P']
+        b = torch.cat([t.ravel() for t in updates])
+
+        sol = _try_cholesky_solve(P, b)
+        if sol is None: sol = _try_lu_solve(P, b)
+        if sol is None: sol = _least_squares_solve(P, b)
+
+        vec_to_tensors_(sol, updates)
        return objective

    @torch.no_grad
    def get_H(self, objective=...):
-
-        if len(Hs) == 1: return Dense(Hs[0])
-        return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
+        return Dense(self.global_state["P"])
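The NewtonNewton refactor above replaces the old chain of per-order solves with a single accumulated matrix P = H1 @ H2 @ ... built in update_states and solved once against the incoming update in apply_states. The two formulations agree because (H1 H2)^-1 = H2^-1 H1^-1. A small check of that equivalence with stand-in matrices (the random, diagonally shifted matrices below only play the role of the Hessians the module actually differentiates):

```python
import torch

torch.manual_seed(0)
n = 5
# two illustrative invertible "Hessians"; any invertible matrices satisfy the identity
H1 = torch.randn(n, n) + n * torch.eye(n)
H2 = torch.randn(n, n) + n * torch.eye(n)
g = torch.randn(n)

# sequential form from the docstring: solve H1 x = g, then H2 x2 = x
x = torch.linalg.solve(H1, g)
x2 = torch.linalg.solve(H2, x)

# accumulated form from the new update_states/apply_states split:
# P = H1 @ H2 is stored once, then P b = g is solved in apply_states
P = H1 @ H2
x2_accumulated = torch.linalg.solve(P, g)

print(torch.allclose(x2, x2_accumulated, atol=1e-5))  # True: (H1 H2)^-1 = H2^-1 H1^-1
```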
torchzero/modules/grad_approximation/fdm.py
CHANGED

@@ -106,12 +106,12 @@ class FDM(GradApproximator):
    plain FDM:

    ```python
-    fdm = tz.
+    fdm = tz.Optimizer(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
    ```

    Any gradient-based method can use FDM-estimated gradients.
    ```python
-    fdm_ncg = tz.
+    fdm_ncg = tz.Optimizer(
        model.parameters(),
        tz.m.FDM(),
        # set hvp_method to "forward" so that it
torchzero/modules/grad_approximation/rfdm.py
CHANGED

@@ -174,7 +174,7 @@ class RandomizedFDM(GradApproximator):

    SPSA is randomized FDM with rademacher distribution and central formula.
    ```py
-    spsa = tz.
+    spsa = tz.Optimizer(
        model.parameters(),
        tz.m.RandomizedFDM(formula="fd_central", distribution="rademacher"),
        tz.m.LR(1e-2)
@@ -185,7 +185,7 @@ class RandomizedFDM(GradApproximator):

    RDSA is randomized FDM with usually gaussian distribution and central formula.
    ```
-    rdsa = tz.
+    rdsa = tz.Optimizer(
        model.parameters(),
        tz.m.RandomizedFDM(formula="fd_central", distribution="gaussian"),
        tz.m.LR(1e-2)
@@ -196,7 +196,7 @@ class RandomizedFDM(GradApproximator):

    GS uses many gaussian samples with possibly a larger finite difference step size.
    ```
-    gs = tz.
+    gs = tz.Optimizer(
        model.parameters(),
        tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
        tz.m.NewtonCG(hvp_method="forward"),
@@ -208,7 +208,7 @@ class RandomizedFDM(GradApproximator):

    Momentum might help by reducing the variance of the estimated gradients.
    ```
-    momentum_spsa = tz.
+    momentum_spsa = tz.Optimizer(
        model.parameters(),
        tz.m.RandomizedFDM(),
        tz.m.HeavyBall(0.9),
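For reference, the estimator that the SPSA example above configures (central formula with a Rademacher probe distribution) has the general shape sketched below; this is an illustrative stand-alone function, not RandomizedFDM's internal implementation.

```python
import torch

def spsa_grad(f, x, h=1e-3, n_samples=1):
    """Central-difference randomized gradient estimate with Rademacher probes.

    Schematic only; names and details are illustrative, not the library's internals.
    """
    g = torch.zeros_like(x)
    for _ in range(n_samples):
        v = (torch.randint(0, 2, x.shape) * 2 - 1).to(x.dtype)   # random +-1 probe
        g += (f(x + h * v) - f(x - h * v)) / (2 * h) * v         # central formula
    return g / n_samples

f = lambda x: (x ** 2).sum()            # toy objective with gradient 2x
x = torch.tensor([1.0, -2.0, 3.0])
print(spsa_grad(f, x, n_samples=200))   # roughly [2., -4., 6.] in expectation
```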
torchzero/modules/least_squares/gn.py
CHANGED

@@ -1,12 +1,12 @@
import torch

-from ...core import Chainable,
+from ...core import Chainable, Transform
from ...linalg import linear_operator
from ...utils import vec_to_tensors
from ...utils.derivatives import flatten_jacobian, jacobian_wrt


-class SumOfSquares(
+class SumOfSquares(Transform):
    """Sets loss to be the sum of squares of values returned by the closure.

    This is meant to be used to test least squares methods against ordinary minimization methods.
@@ -18,7 +18,7 @@ class SumOfSquares(Module):
        super().__init__()

    @torch.no_grad
-    def
+    def update_states(self, objective, states, settings):
        closure = objective.closure

        if closure is not None:
@@ -43,7 +43,11 @@ class SumOfSquares(Module):
        if objective.loss_approx is not None:
            objective.loss_approx = objective.loss_approx.pow(2).sum()

-
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        return objective
+
+class GaussNewton(Transform):
    """Gauss-newton method.

    To use this, the closure should return a vector of values to minimize sum of squares of.
@@ -57,6 +61,9 @@ class GaussNewton(Module):

    Args:
        reg (float, optional): regularization parameter. Defaults to 1e-8.
+        update_freq (int, optional):
+            frequency of computing the jacobian. When jacobian is not computed, only residuals are computed and updated.
+            Defaults to 1.
        batched (bool, optional): whether to use vmapping. Defaults to True.

    Examples:
@@ -68,7 +75,7 @@ class GaussNewton(Module):
            return torch.stack([(1 - x1), 100 * (x2 - x1**2)])

        X = torch.tensor([-1.1, 2.5], requires_grad=True)
-        opt = tz.
+        opt = tz.Optimizer([X], tz.m.GaussNewton(), tz.m.Backtracking())

        # define the closure for line search
        def closure(backward=True):
@@ -86,7 +93,7 @@ class GaussNewton(Module):
        y = torch.randn(64, 10)

        model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.
+        opt = tz.Optimizer(
            model.parameters(),
            tz.m.TrustCG(tz.m.GaussNewton()),
        )
@@ -101,33 +108,49 @@ class GaussNewton(Module):
        print(f'{losses.mean() = }')
        ```
    """
-    def __init__(self, reg:float = 1e-8, batched:bool=True, inner: Chainable | None = None):
-
+    def __init__(self, reg:float = 1e-8, update_freq: int= 1, batched:bool=True, inner: Chainable | None = None):
+        defaults=dict(update_freq=update_freq,batched=batched, reg=reg)
+        super().__init__(defaults=defaults)
        if inner is not None: self.set_child('inner', inner)

    @torch.no_grad
-    def
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
        params = objective.params
-        batched = self.defaults['batched']
-
        closure = objective.closure
-
+        batched = fs['batched']
+        update_freq = fs['update_freq']
+
+        # compute residuals
+        r = objective.loss
+        if r is None:
+            assert closure is not None
+            with torch.enable_grad():
+                r = objective.get_loss(backward=False) # n_residuals
+            assert isinstance(r, torch.Tensor)
+
+        # set sum of squares scalar loss and it's gradient to objective
+        objective.loss = r.pow(2).sum()

-
-        with torch.enable_grad():
-            r = objective.get_loss(backward=False) # nresiduals
-            assert isinstance(r, torch.Tensor)
-            J_list = jacobian_wrt([r.ravel()], params, batched=batched)
+        step = self.increment_counter("step", start=0)

-
+        if step % update_freq == 0:
+
+            # compute jacobian
+            with torch.enable_grad():
+                J_list = jacobian_wrt([r.ravel()], params, batched=batched)
+
+            J = self.global_state["J"] = flatten_jacobian(J_list) # (n_residuals, ndim)
+
+        else:
+            J = self.global_state["J"]

-        J = self.global_state["J"] = flatten_jacobian(J_list) # (nresiduals, ndim)
        Jr = J.T @ r.detach() # (ndim)

        # if there are more residuals, solve (J^T J)x = J^T r, so we need Jr
        # otherwise solve (J J^T)z = r and set x = J^T z, so we need r
-
-        if
+        n_residuals, ndim = J.shape
+        if n_residuals >= ndim or "inner" in self.children:
            self.global_state["Jr"] = Jr

        else:
@@ -136,8 +159,9 @@ class GaussNewton(Module):
            objective.grads = vec_to_tensors(Jr, objective.params)

        # set closure to calculate sum of squares for line searches etc
-        if
+        if closure is not None:
            def sos_closure(backward=True):
+
                if backward:
                    objective.zero_grad()
                    with torch.enable_grad():
@@ -151,8 +175,9 @@ class GaussNewton(Module):
            objective.closure = sos_closure

    @torch.no_grad
-    def
-
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        reg = fs['reg']

        J: torch.Tensor = self.global_state['J']
        nresiduals, ndim = J.shape
@@ -170,39 +195,37 @@ class GaussNewton(Module):
            Jr_list = objective.get_updates()
            Jr = torch.cat([t.ravel() for t in Jr_list])

-
+            JtJ = J.T @ J # (ndim, ndim)
            if reg != 0:
-
+                JtJ.add_(torch.eye(JtJ.size(0), device=JtJ.device, dtype=JtJ.dtype).mul_(reg))

            if nresiduals >= ndim:
-                v, info = torch.linalg.solve_ex(
+                v, info = torch.linalg.solve_ex(JtJ, Jr) # pylint:disable=not-callable
            else:
-                v = torch.linalg.lstsq(
+                v = torch.linalg.lstsq(JtJ, Jr).solution # pylint:disable=not-callable

            objective.updates = vec_to_tensors(v, objective.params)
            return objective

-        else:
-
-
-
-
-
-
+        # else:
+        # solve (J J^T)z = r and set v = J^T z
+        # we need (J^T J)v = J^T r
+        # if z is solution to (G G^T)z = r, and v = J^T z
+        # then (J^T J)v = (J^T J) (J^T z) = J^T (J J^T) z = J^T r
+        # therefore (J^T J)v = J^T r
+        # also this gives a minimum norm solution

-
+        r = self.global_state['r']

-
+        JJT = J @ J.T # (nresiduals, nresiduals)
+        if reg != 0:
+            JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))

-
-
-            JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))
-
-            z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
-            v = J.T @ z
+        z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
+        v = J.T @ z

-
-
+        objective.updates = vec_to_tensors(v, objective.params)
+        return objective

    def get_H(self, objective=...):
        J = self.global_state['J']
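The new comment block in GaussNewton.apply_states justifies the under-determined branch: when there are fewer residuals than parameters, solving (J J^T) z = r and setting v = J^T z yields a solution of the normal equations (J^T J) v = J^T r, and it is the minimum-norm one. A quick numerical confirmation of that identity (shapes and data below are arbitrary):

```python
import torch

torch.manual_seed(0)
n_residuals, ndim = 4, 10          # fewer residuals than parameters
J = torch.randn(n_residuals, ndim)
r = torch.randn(n_residuals)

# dual route used when n_residuals < ndim: solve (J J^T) z = r, then v = J^T z
z = torch.linalg.solve(J @ J.T, r)
v = J.T @ z

# v satisfies the normal equations (J^T J) v = J^T r ...
print(torch.allclose(J.T @ J @ v, J.T @ r, atol=1e-5))               # True
# ... and matches the minimum-norm (pseudoinverse) solution
print(torch.allclose(v, torch.linalg.pinv(J) @ r, atol=1e-4))        # True
```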
torchzero/modules/line_search/backtracking.py
CHANGED

@@ -77,7 +77,7 @@ class Backtracking(LineSearchBase):
    Gradient descent with backtracking line search:

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Backtracking()
    )
@@ -85,7 +85,7 @@ class Backtracking(LineSearchBase):

    L-BFGS with backtracking line search:
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.Backtracking()
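For context on the backtracking examples above, a generic Armijo backtracking loop has the shape sketched below; the constants and interface are illustrative reference code, not tz.m.Backtracking's actual parameters.

```python
import torch

def armijo_backtrack(f, x, d, g, alpha0=1.0, c=1e-4, beta=0.5, max_iter=50):
    """Generic Armijo backtracking: shrink alpha until sufficient decrease holds."""
    f0 = f(x)
    slope = torch.dot(g, d)          # directional derivative along d (negative for descent)
    alpha = alpha0
    for _ in range(max_iter):
        if f(x + alpha * d) <= f0 + c * alpha * slope:
            return alpha
        alpha *= beta
    return alpha

f = lambda x: ((x - 3) ** 2).sum()
x = torch.zeros(2)
g = 2 * (x - 3)                      # gradient at x
alpha = armijo_backtrack(f, x, d=-g, g=g)
print(alpha, f(x + alpha * (-g)))    # step size satisfying sufficient decrease
```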
torchzero/modules/line_search/strong_wolfe.py
CHANGED

@@ -236,7 +236,7 @@ class StrongWolfe(LineSearchBase):
    Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well scaled directions, initial alpha can be determined from function values by ``a_init="first-order"``.

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.PolakRibiere(),
        tz.m.StrongWolfe(c2=0.1, a_init="first-order")
@@ -245,7 +245,7 @@ class StrongWolfe(LineSearchBase):

    LBFGS strong wolfe line search:
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.StrongWolfe()
torchzero/modules/misc/escape.py
CHANGED

torchzero/modules/misc/misc.py
CHANGED

@@ -129,7 +129,7 @@ class Online(Module):

    Online L-BFGS with Backtracking line search
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Online(tz.m.LBFGS()),
        tz.m.Backtracking()
@@ -138,19 +138,16 @@ class Online(Module):

    Online L-BFGS trust region
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
    )
    ```

    """
-    def __init__(self,
+    def __init__(self, module: Module,):
        super().__init__()
-
-        raise RuntimeError("Online got empty list of modules. To make a module online, wrap it in tz.m.Online, e.g. `tz.m.Online(tz.m.LBFGS())`")
-
-        self.set_child('module', modules)
+        self.set_child('module', module)

    @torch.no_grad
    def update(self, objective):
torchzero/modules/misc/regularization.py
CHANGED

@@ -23,7 +23,7 @@ class Dropout(Transform):
    Gradient dropout.

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Dropout(0.5),
        tz.m.Adam(),
@@ -34,7 +34,7 @@ class Dropout(Transform):
    Update dropout.

    ``python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.Dropout(0.5),