torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0

--- a/torchzero/modules/adaptive/muon.py
+++ b/torchzero/modules/adaptive/muon.py
@@ -10,8 +10,10 @@ from ...linalg.orthogonalize import orthogonalize as _orthogonalize, Orthogonali
 def reverse_dims(t:torch.Tensor):
     return t.permute(*reversed(range(t.ndim)))
 
-def _is_at_least_2d(p: torch.Tensor):
-    if
+def _is_at_least_2d(p: torch.Tensor, channel_first:bool):
+    if p.ndim < 2: return False
+    if channel_first and (p.size(0) > 1) and (p.size(1) > 1): return True
+    if (not channel_first) and (p.size(-2) > 1) and (p.size(-1) > 1): return True
     return False
 
 def _orthogonalize_format(
@@ -19,6 +21,7 @@ _orthogonalize_format(
     method: OrthogonalizeMethod,
     channel_first: bool,
 ):
+    """orthogonalize either 1st two dims if channel first or last two otherwise"""
     if channel_first:
         return reverse_dims(_orthogonalize(reverse_dims(tensor), method=method))
 
@@ -69,7 +72,7 @@ orthogonalize_grads_(
     are considered batch dimensions.
     """
     for p in params:
-        if (p.grad is not None) and _is_at_least_2d(p.grad):
+        if (p.grad is not None) and _is_at_least_2d(p.grad, channel_first=channel_first):
             X = _orthogonalize_format(p.grad, method=method, channel_first=channel_first)
             if dual_norm_correction: X = _dual_norm_correction(X, p.grad, channel_first=False)
             p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]
@@ -100,7 +103,7 @@ class Orthogonalize(TensorTransform):
 
     standard Muon with Adam fallback
     ```py
-    opt = tz.
+    opt = tz.Optimizer(
         model.head.parameters(),
         tz.m.Split(
             # apply muon only to 2D+ parameters
@@ -131,7 +134,7 @@ class Orthogonalize(TensorTransform):
 
         if not orthogonalize: return tensor
 
-        if _is_at_least_2d(tensor):
+        if _is_at_least_2d(tensor, channel_first=channel_first):
 
             X = _orthogonalize_format(tensor, method, channel_first=channel_first)
 
@@ -173,7 +176,7 @@ class MuonAdjustLR(Transform):
         alphas = [s['alpha'] for s in settings]
         channel_first = [s["channel_first=channel_first"] for s in settings]
         tensors_alphas = [
-            (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t)
+            (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t, channel_first=cf)
         ]
         tensors = [i[0] for i in tensors_alphas]
         a = [i[1] for i in alphas]
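The change above threads `channel_first` into `_is_at_least_2d`, so the "is this worth orthogonalizing" check now inspects the same pair of dimensions Muon will actually orthogonalize. A minimal standalone sketch of that predicate (illustrative only, not an import from the package):

```python
import torch

def is_at_least_2d(p: torch.Tensor, channel_first: bool) -> bool:
    # qualify only if both dimensions being orthogonalized are non-trivial (> 1)
    if p.ndim < 2:
        return False
    if channel_first:
        return p.size(0) > 1 and p.size(1) > 1    # first two dims
    return p.size(-2) > 1 and p.size(-1) > 1      # last two dims

conv_w = torch.randn(16, 3, 3, 3)   # conv weight: dims 0 and 1 are 16 and 3
col = torch.randn(16, 1)            # effectively a vector: dim 1 is trivial
print(is_at_least_2d(conv_w, channel_first=True))   # True  -> orthogonalize
print(is_at_least_2d(col, channel_first=False))     # False -> leave to the fallback (e.g. Adam)
```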

--- a/torchzero/modules/adaptive/natural_gradient.py
+++ b/torchzero/modules/adaptive/natural_gradient.py
@@ -4,7 +4,7 @@ from ...core import Transform
 from ...utils.derivatives import jacobian_wrt, flatten_jacobian
 from ...utils import vec_to_tensors
 from ...linalg import linear_operator
-from .
+from .ggt import ggt_update
 
 class NaturalGradient(Transform):
     """Natural gradient approximated via empirical fisher information matrix.
@@ -41,7 +41,7 @@ class NaturalGradient(Transform):
         y = torch.randn(64, 10)
 
         model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.NaturalGradient(),
             tz.m.LR(3e-2)
@@ -61,7 +61,7 @@ class NaturalGradient(Transform):
         y = torch.randn(64, 10)
 
         model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.NaturalGradient(),
             tz.m.LR(3e-2)
@@ -84,7 +84,7 @@ class NaturalGradient(Transform):
             return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])
 
         X = torch.tensor([-1.1, 2.5], requires_grad=True)
-        opt = tz.
+        opt = tz.Optimizer([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))
 
         for iter in range(200):
             losses = rosenbrock(X)
@@ -99,18 +99,24 @@ class NaturalGradient(Transform):
     @torch.no_grad
     def update_states(self, objective, states, settings):
         params = objective.params
+        closure = objective.closure
         fs = settings[0]
         batched = fs['batched']
         gn_grad = fs['gn_grad']
 
-
-
+        # compute per-sample losses
+        f = objective.loss
+        if f is None:
+            assert closure is not None
+            with torch.enable_grad():
+                f = objective.get_loss(backward=False) # n_out
+            assert isinstance(f, torch.Tensor)
 
+        # compute per-sample gradients
         with torch.enable_grad():
-            f = objective.get_loss(backward=False) # n_out
-            assert isinstance(f, torch.Tensor)
             G_list = jacobian_wrt([f.ravel()], params, batched=batched)
 
+        # set scalar loss and it's grad to objective
         objective.loss = f.sum()
         G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)
 
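For orientation on the hunk above: each row of `G` is one per-sample gradient, so `G` has shape `(n_samples, ndim)` and the empirical Fisher used later is `GᵀG`. A hedged plain-autograd illustration of that matrix, without the `jacobian_wrt`/`flatten_jacobian` helpers (the toy loss and names here are made up):

```python
import torch

theta = torch.randn(3, requires_grad=True)
x = torch.randn(5, 3)
f = (x @ theta) ** 2                       # 5 per-sample losses

rows = []
for fi in f:                               # one gradient row per sample
    (gi,) = torch.autograd.grad(fi, (theta,), retain_graph=True)
    rows.append(gi.reshape(-1))
G = torch.stack(rows)                      # (n_samples, ndim) = (5, 3)

fisher = G.T @ G                           # empirical Fisher approximation, (3, 3)
print(G.shape, fisher.shape)
```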
@@ -123,8 +129,10 @@ class NaturalGradient(Transform):
         objective.grads = vec_to_tensors(g, params)
 
         # set closure to calculate scalar value for line searches etc
-        if
+        if closure is not None:
+
             def ngd_closure(backward=True):
+
                 if backward:
                     objective.zero_grad()
                     with torch.enable_grad():
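The `ngd_closure` wrapper added above exists so that later modules (line searches and similar) can still ask for a plain scalar loss after the per-sample machinery has replaced it. A rough sketch of that closure convention, with every name hypothetical rather than the torchzero API:

```python
import torch

def make_scalar_closure(per_sample_loss_fn, params):
    # closure(backward=True) returns the summed scalar loss and, if asked,
    # repopulates .grad, which is what a line search needs to re-evaluate points
    def closure(backward: bool = True):
        if backward:
            for p in params:
                p.grad = None
            with torch.enable_grad():
                loss = per_sample_loss_fn().sum()
                loss.backward()
            return loss
        with torch.no_grad():
            return per_sample_loss_fn().sum()
    return closure

theta = torch.randn(3, requires_grad=True)
x = torch.randn(5, 3)
closure = make_scalar_closure(lambda: (x @ theta) ** 2, [theta])
print(closure(False).item(), closure().item(), theta.grad.shape)
```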
@@ -152,22 +160,31 @@ class NaturalGradient(Transform):
         if sqrt:
             # this computes U, S <- SVD(M), then calculate update as U S^-1 Uᵀg,
             # but it computes it through eigendecompotision
-
-
+            L, U = ggt_update(G.H, damping=reg, rdamping=1e-16, truncate=0, eig_tol=1e-12)
+
+            if U is None or L is None:
+
+                # fallback to element-wise
+                g = self.global_state["g"]
+                g /= G.square().mean(0).sqrt().add(reg)
+                objective.updates = vec_to_tensors(g, params)
+                return objective
 
-
+            # whiten
+            z = U.T @ self.global_state["g"]
+            v = (U * L.rsqrt()) @ z
             objective.updates = vec_to_tensors(v, params)
             return objective
 
         # we need (G^T G)v = g
         # where g = G^T
         # so we need to solve (G^T G)v = G^T
-
+        GGt = G @ G.H # (n_samples, n_samples)
 
         if reg != 0:
-
+            GGt.add_(torch.eye(GGt.size(0), device=GGt.device, dtype=GGt.dtype).mul_(reg))
 
-        z, _ = torch.linalg.solve_ex(
+        z, _ = torch.linalg.solve_ex(GGt, torch.ones_like(GGt[0])) # pylint:disable=not-callable
         v = G.H @ z
 
         objective.updates = vec_to_tensors(v, params)
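The comments in the last hunk compress the key linear-algebra step: with `g = Gᵀ1` (the summed gradient), a solution of `(GᵀG) v = g` can be written `v = Gᵀ z` with `G Gᵀ z = 1`, so only an `n_samples × n_samples` system is factorized instead of an `ndim × ndim` one. A hedged standalone sketch of that Gram-matrix trick (regularization added to `GGᵀ`, as in the hunk; the function name is illustrative):

```python
import torch

def empirical_fisher_solve(G: torch.Tensor, reg: float = 1e-8) -> torch.Tensor:
    # G: (n_samples, ndim) per-sample gradients; returns v with (G^T G) v ≈ G^T 1
    GGt = G @ G.T                                         # small Gram matrix
    if reg != 0:
        GGt = GGt + reg * torch.eye(GGt.size(0), dtype=G.dtype)
    z = torch.linalg.solve(GGt, torch.ones(GGt.size(0), dtype=G.dtype))
    return G.T @ z                                        # v = G^T (G G^T)^-1 1

G = torch.randn(8, 100, dtype=torch.float64)              # 8 samples, 100 parameters
v = empirical_fisher_solve(G)
g = G.T @ torch.ones(8, dtype=torch.float64)
print(torch.allclose(G.T @ (G @ v), g, atol=1e-6))        # True up to the tiny regularizer
```

For 8 samples this means factorizing an 8×8 matrix rather than a 100×100 one, which is why the module only ever forms `G @ G.H`.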

--- /dev/null
+++ b/torchzero/modules/adaptive/psgd/_psgd_utils.py
@@ -0,0 +1,37 @@
+# pylint:disable=not-callable
+import warnings
+
+import torch
+
+from .psgd import lift2single
+
+
+def _initialize_lra_state_(tensor: torch.Tensor, state, setting):
+    n = tensor.numel()
+    rank = max(min(setting["rank"], n-1), 1)
+    dtype=tensor.dtype
+    device=tensor.device
+
+    U = torch.randn((n, rank), dtype=dtype, device=device)
+    U *= 0.1**0.5 / torch.linalg.vector_norm(U)
+
+    V = torch.randn((n, rank), dtype=dtype, device=device)
+    V *= 0.1**0.5 / torch.linalg.vector_norm(V)
+
+    if setting["init_scale"] is None:
+        # warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
+        d = None
+    else:
+        d = torch.ones(n, 1, dtype=dtype, device=device) * setting["init_scale"]
+
+    state["UVd"] = [U, V, d]
+    state["Luvd"] = [lift2single(torch.zeros([], dtype=dtype, device=device)) for _ in range(3)]
+
+
+
+def _wrap_with_no_backward(opt):
+    """to use original psgd opts with visualbench"""
+    class _Wrapped:
+        def step(self, closure):
+            return opt.step(lambda: closure(False))
+    return _Wrapped()
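For orientation, `_initialize_lra_state_` seeds the state of PSGD's low-rank (LRA) preconditioner: two `(n, rank)` factors `U` and `V` drawn at random and rescaled to Frobenius norm `sqrt(0.1)`, plus an optional diagonal scale `d` that stays `None` when it is to be set on the fly. A hedged standalone re-creation of just that initialization, without the torchzero `state`/`setting` plumbing or the `lift2single` helper:

```python
import torch

def init_lra_factors(tensor: torch.Tensor, rank: int, init_scale=None):
    n = tensor.numel()
    rank = max(min(rank, n - 1), 1)                  # clamp rank into [1, n-1]
    U = torch.randn(n, rank)
    U *= 0.1**0.5 / torch.linalg.vector_norm(U)      # ||U||_F = sqrt(0.1)
    V = torch.randn(n, rank)
    V *= 0.1**0.5 / torch.linalg.vector_norm(V)
    # d stays None when the initial scale is to be determined on the fly
    d = None if init_scale is None else torch.full((n, 1), float(init_scale))
    return U, V, d

U, V, d = init_lra_factors(torch.randn(4, 5), rank=10)
print(U.shape, V.shape, d)                            # torch.Size([20, 10]) torch.Size([20, 10]) None
print(round(torch.linalg.vector_norm(U).item(), 4))   # 0.3162 ≈ sqrt(0.1)
```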