torchzero 0.3.5__py3-none-any.whl → 0.3.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +1 -1
- tests/test_tensorlist.py +17 -17
- torchzero/core/preconditioner.py +11 -10
- torchzero/modules/experimental/__init__.py +3 -2
- torchzero/modules/experimental/absoap.py +8 -2
- torchzero/modules/experimental/adadam.py +1 -1
- torchzero/modules/experimental/adamY.py +1 -1
- torchzero/modules/experimental/adasoap.py +1 -1
- torchzero/modules/experimental/algebraic_newton.py +1 -1
- torchzero/modules/experimental/curveball.py +1 -1
- torchzero/modules/experimental/gradmin.py +1 -1
- torchzero/modules/experimental/newton_solver.py +88 -0
- torchzero/modules/experimental/{dsoap.py → soapy.py} +4 -4
- torchzero/modules/experimental/spectral.py +5 -3
- torchzero/modules/experimental/subspace_preconditioners.py +16 -9
- torchzero/modules/optimizers/soap.py +1 -2
- torchzero/modules/projections/projection.py +27 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +4 -3
- torchzero/modules/quasi_newton/lbfgs.py +4 -3
- torchzero/modules/quasi_newton/lsr1.py +6 -3
- torchzero/modules/quasi_newton/quasi_newton.py +16 -17
- torchzero/modules/second_order/__init__.py +1 -1
- torchzero/modules/second_order/newton_cg.py +1 -1
- torchzero/utils/linalg/benchmark.py +20 -0
- torchzero/utils/linalg/solve.py +15 -14
- {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/METADATA +2 -2
- {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/RECORD +30 -28
- {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/WHEEL +0 -0
- {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/top_level.txt +0 -0
tests/test_opts.py
CHANGED
@@ -745,7 +745,7 @@ SSVM = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
     sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
     needs_closure=True,
-    func='rosen', steps=50, loss=
+    func='rosen', steps=50, loss=0.02, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )

tests/test_tensorlist.py
CHANGED
@@ -1301,7 +1301,7 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
     expected_tl = TensorList(expected_list)
     assert isinstance(result, TensorList)
     assert len(result) == len(expected_tl)
-    assert_tl_allclose(result, expected_tl, atol=1e-
+    assert_tl_allclose(result, expected_tl, atol=1e-3) # Use allclose due to potential float variations

 # --- Grafting, Rescaling, Normalizing, Clipping ---

@@ -1381,8 +1381,8 @@ def test_rescale(simple_tl: TensorList, dim):
         assert torch.allclose(rescaled_scalar.global_min(), torch.tensor(min_val))
         assert torch.allclose(rescaled_scalar.global_max(), torch.tensor(max_val))
     else:
-        assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-
-        assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-
+        assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-3)
+        assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-3)


     # Rescale list
@@ -1402,8 +1402,8 @@ def test_rescale(simple_tl: TensorList, dim):
         assert global_max_rescaled < avg_max + 1.0 # Loose check

     else:
-        assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-
-        assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-
+        assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-3)
+        assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-3)

     # Rescale to 01 helper
     rescaled_01 = simple_tl.rescale_to_01(dim=dim, eps=eps)
@@ -1413,8 +1413,8 @@ def test_rescale(simple_tl: TensorList, dim):
         assert torch.allclose(rescaled_01.global_min(), torch.tensor(0.0))
         assert torch.allclose(rescaled_01.global_max(), torch.tensor(1.0))
     else:
-        assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-
-        assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-
+        assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-3)
+        assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-3)


     # Test inplace
@@ -1454,11 +1454,11 @@ def test_normalize(big_tl: TensorList, dim):
     normalized_scalar_var = normalized_scalar.var(dim=dim if dim != 'global' else None)

     if dim == 'global':
-        assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-
-        assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-
+        assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-3)
+        assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-3)
     else:
-        assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-
-        assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-
+        assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-3)
+        assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-3)

     # Normalize list mean/var
     normalized_list = simple_tl.normalize(mean_list, var_list, dim=dim)
@@ -1476,19 +1476,19 @@ def test_normalize(big_tl: TensorList, dim):
         # assert torch.allclose(global_mean_rescaled, torch.tensor(avg_mean), rtol=1e-1, atol=1e-1) # Loose check
         # assert torch.allclose(global_var_rescaled, torch.tensor(avg_var), rtol=1e-1, atol=1e-1) # Loose check
     else:
-        assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-
-        assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-
+        assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-3)
+        assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-3)

     # Z-normalize helper
     znorm = simple_tl.znormalize(dim=dim, eps=1e-10)
     znorm_mean = znorm.mean(dim=dim if dim != 'global' else None)
     znorm_var = znorm.var(dim=dim if dim != 'global' else None)
     if dim == 'global':
-        assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-
-        assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-
+        assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-3)
+        assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-3)
     else:
-        assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-
-        assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-
+        assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-3)
+        assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-3)


     # Test inplace
torchzero/core/preconditioner.py
CHANGED
@@ -38,7 +38,7 @@ class Preconditioner(Transform):


     def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('
+        step = self.global_state.get('__step', 0)
         states = [self.state[p] for p in params]
         settings = [self.settings[p] for p in params]
         global_settings = settings[0]
@@ -47,8 +47,10 @@ class Preconditioner(Transform):
         scale_first = global_settings['__scale_first']
         scale_factor = 0
         if scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS
-
+            # initial step size guess from pytorch LBFGS was too unstable
+            # I switched to norm
+            tensors = TensorList(tensors)
+            scale_factor = tensors.abs().global_mean().clip(min=1)

         # update preconditioner
         if step % update_freq == 0:
@@ -65,11 +67,11 @@ class Preconditioner(Transform):
         if scale_first and step == 0:
             torch._foreach_div_(tensors, scale_factor)

-        self.global_state['
+        self.global_state['__step'] = step + 1
         return tensors

     def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('
+        step = self.global_state.get('__step', 0)
         tensors_vec = torch.cat([t.ravel() for t in tensors])
         params_vec = torch.cat([p.ravel() for p in params])
         grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
@@ -82,8 +84,8 @@ class Preconditioner(Transform):
         scale_first = global_settings['__scale_first']
         scale_factor = 0
         if scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS
-            scale_factor = tensors_vec.abs().
+            # initial step size guess from pytorch LBFGS was too unstable
+            scale_factor = tensors_vec.abs().mean().clip(min=1)

         # update preconditioner
         if step % update_freq == 0:
@@ -99,11 +101,10 @@ class Preconditioner(Transform):

         # scale initial step, when preconditioner might not have been applied
         if scale_first and step == 0:
-
-            tensors_vec /= scale_factor
+            tensors_vec /= scale_factor

         tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
-        self.global_state['
+        self.global_state['__step'] = step + 1
         return tensors

     @torch.no_grad
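To make the change above concrete, here is a minimal sketch (plain torch, not the torchzero TensorList/Transform API) of the new first-step scaling: the update is divided by mean(|update|) clamped to at least 1, so the very first, not-yet-preconditioned step can only be shrunk, never enlarged.

import torch

def scale_first_step(update: torch.Tensor) -> torch.Tensor:
    # mean absolute value, clamped so the factor never enlarges the step
    scale_factor = update.abs().mean().clamp(min=1)
    return update / scale_factor

print(scale_first_step(torch.tensor([100.0, -50.0])))  # shrunk to roughly unit scale
print(scale_first_step(torch.tensor([0.1, -0.2])))     # unchanged: the factor clamps to 1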
torchzero/modules/experimental/__init__.py
CHANGED

@@ -3,7 +3,7 @@ from .adadam import Adadam
 from .adamY import AdamY
 from .adasoap import AdaSOAP
 from .curveball import CurveBall
-from .
+from .soapy import SOAPY
 from .gradmin import GradMin
 from .reduce_outward_lr import ReduceOutwardLR
 from .spectral import SpectralPreconditioner
@@ -11,4 +11,5 @@ from .subspace_preconditioners import (
     HistorySubspacePreconditioning,
     RandomSubspacePreconditioning,
 )
-from .tropical_newton import TropicalNewton
+from .tropical_newton import TropicalNewton
+from .newton_solver import NewtonSolver
torchzero/modules/experimental/absoap.py
CHANGED

@@ -140,11 +140,17 @@ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
 class ABSOAP(Transform):
     """SOAP but with two extra letters included in its name in order to improve converence

+    so what you can do is choose what goes into what ,and that is supposed to be good.
+
     new args

     scale by s whether to scale gradient differences by parameter differences

     y_to_ema2 whether to use gradient differences for exponential moving average too
+
+    okay I changed these args into another ones
+
+    BASICALLY THIS IS FOR MY EXPERIMENTS
     """
     def __init__(
         self,
@@ -213,7 +219,7 @@ class ABSOAP(Transform):
             if 'g_prev' not in state:
                 state['p_prev'] = p.clone()
                 state['g_prev'] = t.clone()
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue

             p_prev = state['p_prev']
@@ -285,7 +291,7 @@ class ABSOAP(Transform):
                 state['Q'] = get_orthogonal_matrix(state['GG'])

                 state['step'] = 0
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

torchzero/modules/experimental/adasoap.py
CHANGED

@@ -218,7 +218,7 @@ class AdaSOAP(Transform):
                 state['Q'] = get_orthogonal_matrix(GG_precond)

                 state['step'] = 0
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

torchzero/modules/experimental/algebraic_newton.py
CHANGED

@@ -71,7 +71,7 @@ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemir


 class AlgebraicNewton(Module):
-    """newton in other algebras, not
+    """newton in other algebras, not that it works."""
     def __init__(
         self,
         reg: float | None = None,
torchzero/modules/experimental/curveball.py
CHANGED

@@ -13,7 +13,7 @@ def curveball(
     momentum: float | NumberList,
     precond_lr: float | NumberList,
 ):
-    """returns z_, clone it!!!"""
+    """returns z_, clone it!!! (no just negate it)"""
     delta = Hz + tensors
     z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
     return z_
torchzero/modules/experimental/gradmin.py
CHANGED

@@ -14,7 +14,7 @@ from ..smoothing.gaussian import Reformulation


 class GradMin(Reformulation):
-    """Reformulates the objective to minimize sum of gradient magnitudes via autograd.
+    """Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.

     Args:
         loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
torchzero/modules/experimental/newton_solver.py
ADDED

@@ -0,0 +1,88 @@
+from collections.abc import Callable, Iterable
+from typing import Any, Literal, overload
+
+import torch
+
+from ...core import Chainable, Module, apply, Modular
+from ...utils import TensorList, as_tensorlist
+from ...utils.derivatives import hvp
+from ..quasi_newton import LBFGS
+
+class NewtonSolver(Module):
+    """Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
+    def __init__(
+        self,
+        solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
+        maxiter=None,
+        tol=1e-3,
+        reg: float = 0,
+        warm_start=True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        solver_cls = settings['solver']
+        maxiter = settings['maxiter']
+        tol = settings['tol']
+        reg = settings['reg']
+        warm_start = settings['warm_start']
+
+        # ---------------------- Hessian vector product function --------------------- #
+        grad = vars.get_grad(create_graph=True)
+
+        def H_mm(x):
+            with torch.enable_grad():
+                Hvp = TensorList(hvp(params, grad, x, create_graph=True))
+                if reg != 0: Hvp = Hvp + (x*reg)
+                return Hvp
+
+        # -------------------------------- inner step -------------------------------- #
+        b = as_tensorlist(grad)
+        if 'inner' in self.children:
+            b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+
+        # ---------------------------------- run cg ---------------------------------- #
+        x0 = None
+        if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+        if x0 is None: x = b.zeros_like().requires_grad_(True)
+        else: x = x0.clone().requires_grad_(True)
+
+        solver = solver_cls(x)
+        def lstsq_closure(backward=True):
+            Hx = H_mm(x)
+            loss = (Hx-b).pow(2).global_mean()
+            if backward:
+                solver.zero_grad()
+                loss.backward(inputs=x)
+            return loss
+
+        if maxiter is None: maxiter = b.global_numel()
+        loss = None
+        initial_loss = lstsq_closure(False)
+        if initial_loss > tol:
+            for i in range(maxiter):
+                loss = solver.step(lstsq_closure)
+                assert loss is not None
+                if min(loss, loss/initial_loss) < tol: break
+
+            print(f'{loss = }')
+
+        if warm_start:
+            assert x0 is not None
+            x0.copy_(x)
+
+        vars.update = x.detach()
+        return vars
+
+
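The core idea of NewtonSolver above is to recast the Newton system H x = g as the least-squares problem min_x ||H x - g||^2 and drive it with an ordinary optimizer that only needs Hessian-vector products. Below is a minimal pure-torch illustration of that trick on an explicit 2x2 matrix; it is not the torchzero API, just the underlying idea with H and g as stand-ins.

import torch

H = torch.tensor([[2.0, 0.5], [0.5, 1.0]])   # stand-in for the Hessian (HVPs in the real module)
g = torch.tensor([1.0, -1.0])                 # stand-in for the gradient

x = torch.zeros(2, requires_grad=True)
opt = torch.optim.LBFGS([x], max_iter=50)

def closure():
    opt.zero_grad()
    loss = (H @ x - g).pow(2).mean()          # least-squares residual of the Newton system
    loss.backward()
    return loss

opt.step(closure)
print(torch.allclose(x.detach(), torch.linalg.solve(H, g), atol=1e-3))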
torchzero/modules/experimental/{dsoap.py → soapy.py}
RENAMED

@@ -3,7 +3,7 @@ from operator import itemgetter
 import torch

 from ...core import Chainable, Transform, apply
-from
+from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims

 @torch.no_grad
 def update_soap_covariances_(
@@ -135,7 +135,7 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N

     return final, exp_avg_sq

-class
+class SOAPY(Transform):
     """SOAP but uses scaled gradient differences

     new args
@@ -195,7 +195,7 @@ class DSOAP(Transform):
             if 'g_prev' not in state:
                 state['p_prev'] = p.clone()
                 state['g_prev'] = t.clone()
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue

             p_prev = state['p_prev']
@@ -228,7 +228,7 @@ class DSOAP(Transform):
                 state['Q'] = get_orthogonal_matrix(state['GG'])

                 state['step'] = 0
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

torchzero/modules/experimental/spectral.py
CHANGED

@@ -194,8 +194,10 @@ class SpectralPreconditioner(TensorwisePreconditioner):
         order (int, optional):
             whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
         solver (str, optional): what to use for whitening. Defaults to 'svd'.
-
-
+        A_beta (float | None, optional):
+            beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
+        B_beta (float | None, optional):
+            beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
         interval (int, optional): How often to update history. Defaults to 1 (every step).
         concat_params (bool, optional):
             whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
@@ -275,7 +277,7 @@ class SpectralPreconditioner(TensorwisePreconditioner):
         A = state.get('A', None)
         if A is None:
             # make a conservative step to avoid issues due to different GD scaling
-            return tensor.
+            return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

         B = state['B']
         update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
torchzero/modules/experimental/subspace_preconditioners.py
CHANGED

@@ -38,16 +38,19 @@ def apply_subspace_preconditioner(
     return basis @ update_projected # d

 class RandomSubspacePreconditioning(Transform):
-    """full matrix rmsprop in random subspace"""
-    def __init__(self, k: int, beta: float | None = 0.99):
-        defaults = dict(k=k, beta=beta)
+    """full matrix rmsprop in random slowly changing subspace"""
+    def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
+        defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
         super().__init__(defaults, uses_grad=False)

+        if inner is not None: self.set_child('inner', inner)
+
     def transform(self, tensors, params, grads, vars):
         settings = self.settings[params[0]]
         g = torch.cat([t.view(-1) for t in tensors])
         k = settings['k']
         beta = settings['beta']
+        basis_beta = settings['basis_beta']

         if 'basis' not in self.global_state:
             self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
@@ -56,13 +59,19 @@ class RandomSubspacePreconditioning(Transform):
         basis = self.global_state['basis']
         accumulator = self.global_state['accumulator']

+        if basis_beta is not None:
+            basis.lerp_(torch.randn_like(basis), 1-basis_beta)
+
         update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+        if 'inner' in self.children:
+            tensors = apply(self.children['inner'], tensors, params, grads, vars)
+            g = torch.cat([t.view(-1) for t in tensors])
+
         try:
             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
         except torch.linalg.LinAlgError:
-
-            if denom <= 1e-10: denom = torch.ones_like(denom)
-            preconditioned = g / g.abs().sum()
+            preconditioned = g.clip(-0.1, 0.1)
         vec_to_tensors_(preconditioned, tensors)

         return tensors
@@ -119,9 +128,7 @@ class HistorySubspacePreconditioning(Transform):
         try:
             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
         except torch.linalg.LinAlgError:
-
-            if denom <= 1e-10: denom = torch.ones_like(denom)
-            preconditioned = g / g.abs().sum()
+            preconditioned = g.clip(-0.1,0.1)
         vec_to_tensors_(preconditioned, tensors)

         return tensors
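A small sketch of the new "slowly changing" random subspace introduced above: each step the basis is nudged toward a fresh Gaussian draw by a factor of 1 - basis_beta, so with basis_beta close to 1 the subspace drifts gradually instead of staying fixed. This is an illustration outside the Transform machinery, with n, k and basis_beta chosen arbitrarily.

import torch

n, k, basis_beta = 1000, 16, 0.99
basis = torch.randn(n, k)
old = basis.clone()
for _ in range(10):
    # lerp_ moves the basis 1% of the way toward a new random draw each step
    basis.lerp_(torch.randn_like(basis), 1 - basis_beta)
print(torch.nn.functional.cosine_similarity(basis.ravel(), old.ravel(), dim=0))  # still close to 1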
torchzero/modules/optimizers/soap.py
CHANGED

@@ -222,8 +222,7 @@ class SOAP(Transform):
                 state['Q'] = get_orthogonal_matrix(state['GG'])

                 state['step'] = 0
-                updates.append(tensors[i].
-                # updates.append(tensors[i] / tensors[i].abs().sum())
+                updates.append(tensors[i].clip(-0.1, 0.1))
                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                 # I use scaled update instead as to not mess up with next modules.

torchzero/modules/projections/projection.py
CHANGED

@@ -1,4 +1,5 @@
 import math
+from functools import partial
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Any, Literal
@@ -33,6 +34,25 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",

     return projected_closure

+def _projected_get_grad_override(
+    retain_graph: bool | None = None,
+    create_graph: bool = False,
+    projection: Any = ...,
+    unprojected_vars: Any = ...,
+    self: Any = ...,
+):
+    assert isinstance(projection, Projection)
+    assert isinstance(unprojected_vars, Vars)
+    assert isinstance(self, Vars)
+
+    if self.grad is not None: return self.grad
+    grads = unprojected_vars.get_grad(retain_graph, create_graph)
+    projected_grads = list(projection.project(grads, self, current='grads'))
+    self.grad = projected_grads
+    for p, g in zip(self.params, projected_grads):
+        p.grad = g
+    return self.grad
+

 class Projection(Module, ABC):
     """
@@ -137,6 +157,12 @@ class Projection(Module, ABC):

         # step
         projected_vars.params = self._projected_params
+        projected_vars.get_grad = partial(
+            _projected_get_grad_override,
+            projection=self,
+            unprojected_vars=vars,
+            self=projected_vars,
+        )
         projected_vars = self.children['modules'].step(projected_vars)

         # empty fake params storage
@@ -149,7 +175,7 @@ class Projection(Module, ABC):
         unprojected_vars = projected_vars.clone(clone_update=False)
         unprojected_vars.closure = vars.closure
         unprojected_vars.params = vars.params
-
+        unprojected_vars.grad = vars.grad

         if self._project_update:
             assert projected_vars.update is not None
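The get_grad override above follows a general pattern: replace a method on one object with a functools.partial that closes over the projection and the unprojected vars, so gradients are computed lazily in the original space and projected on demand (and cached). A self-contained sketch of just that pattern; the classes here are stand-ins, not the torchzero Vars/Projection types.

from functools import partial

class FakeVars:
    def __init__(self):
        self.grad = None
    def get_grad(self):
        self.grad = [1.0, 2.0]          # stand-in for an autograd call
        return self.grad

def projected_get_grad(projection, unprojected, self):
    if self.grad is not None:
        return self.grad                 # already computed, return cached value
    self.grad = [projection(g) for g in unprojected.get_grad()]
    return self.grad

unprojected, projected = FakeVars(), FakeVars()
projected.get_grad = partial(projected_get_grad, lambda g: 2 * g, unprojected, projected)
print(projected.get_grad())              # [2.0, 4.0] — computed once, then cached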
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py
CHANGED

@@ -37,10 +37,11 @@ def lbfgs(
     z_tfm: Any,
 ):
     if len(s_history) == 0 or y_k is None or ys_k is None:
-        # dir = params.grad.sign() # may work fine

-        # initial step size guess
-
+        # initial step size guess modified from pytorch L-BFGS
+        scale = 1 / tensors_.abs().global_sum()
+        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]

     else:
         # 1st loop
torchzero/modules/quasi_newton/lbfgs.py
CHANGED

@@ -36,10 +36,11 @@ def lbfgs(
     step: int,
 ):
     if len(s_history) == 0 or y_k is None or ys_k is None:
-        # dir = params.grad.sign() # may work fine

-        # initial step size guess
-
+        # initial step size guess modified from pytorch L-BFGS
+        scale = 1 / tensors_.abs().global_sum()
+        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]

     else:
         # 1st loop
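Both L-BFGS variants above (and LSR1 below) now share the same guarded first-iteration step-size guess. A plain-torch sketch of that rule (the real code operates on a TensorList): scale the gradient by min(1, 1/sum|g|), falling back to 1/mean|g| when the sum-based factor underflows.

import torch

def first_step(g: torch.Tensor) -> torch.Tensor:
    scale = 1 / g.abs().sum()
    if scale < 1e-5:                  # very large gradients: sum-based scale is too small
        scale = 1 / g.abs().mean()
    return g * min(1.0, scale.item())

print(first_step(torch.tensor([0.5, -0.5])))      # already unit L1 norm, left unchanged
print(first_step(torch.tensor([1e6, -1e6])))      # falls back to mean-based scaling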
torchzero/modules/quasi_newton/lsr1.py
CHANGED

@@ -17,8 +17,9 @@ def lsr1_(
 ):
     if step == 0 or not s_history:
         # initial step size guess from pytorch
-
-
+        scale = 1 / tensors_.abs().global_sum()
+        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]

     m = len(s_history)

@@ -64,7 +65,9 @@ def lsr1_(
         Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]

     if scale_second and step == 1:
-
+        scale = 1 / tensors_.abs().global_sum()
+        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+        Hx.mul_(min(1.0, scale)) # pyright:ignore[reportArgumentType]
     return Hx

torchzero/modules/quasi_newton/quasi_newton.py
CHANGED

@@ -68,6 +68,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         M_key = 'H' if inverse else 'B'
         M = state.get(M_key, None)
         step = state.get('step', 0)
+        state['step'] = step + 1
         init_scale = settings['init_scale']
         tol = settings['tol']
         tol_reset = settings['tol_reset']
@@ -91,13 +92,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         state['p_prev'].copy_(p)
         state['g_prev'].copy_(g)

-
-        if reset_interval is not None and step % reset_interval == 0:
+        if reset_interval is not None and step != 0 and step % reset_interval == 0:
             self._reset_M_(M, s, y, inverse, init_scale)
             return

         # tolerance on gradient difference to avoid exploding after converging
-
+        elif y.abs().max() <= tol:
             # reset history
             if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
             return
@@ -119,11 +119,10 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):

     @torch.no_grad
     def apply_tensor(self, tensor, param, grad, state, settings):
-        step = state
+        step = state.get('step', 0)

         if settings['scale_second'] and step == 2:
-
-            if s < settings['tol']: tensor = tensor/s
+            tensor = tensor / tensor.abs().mean().clip(min=1)

         inverse = settings['inverse']
         if inverse:
@@ -135,7 +134,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
             return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable

 # to avoid typing all arguments for each method
-class
+class HUpdateStrategy(HessianUpdateStrategy):
     def __init__(
         self,
         init_scale: float | Literal["auto"] = "auto",
@@ -174,7 +173,7 @@ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     H += term1.sub_(term2)
     return H

-class BFGS(
+class BFGS(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])

@@ -193,7 +192,7 @@ def sr1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
     H += torch.outer(z, z).div_(denom)
     return H

-class SR1(
+class SR1(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])

@@ -213,7 +212,7 @@ def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     H += term1.sub_(term2)
     return H

-class DFP(
+class DFP(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])

@@ -254,19 +253,19 @@ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     H -= num/denom
     return H

-class BroydenGood(
+class BroydenGood(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])

-class BroydenBad(
+class BroydenBad(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])

-class Greenstadt1(
+class Greenstadt1(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])

-class Greenstadt2(
+class Greenstadt2(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])

@@ -287,7 +286,7 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
     H[:, j] += num.squeeze() / denom
     return H

-class ColumnUpdatingMethod(
+class ColumnUpdatingMethod(HUpdateStrategy):
     """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
@@ -307,7 +306,7 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor,
     H -= num/denom
     return H, R

-class ThomasOptimalMethod(
+class ThomasOptimalMethod(HUpdateStrategy):
     """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
@@ -364,7 +363,7 @@ def pearson2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     H += num.div_(sy)
     return H

-class Pearson2(
+class Pearson2(HUpdateStrategy):
     """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
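One of the quieter fixes above is the step != 0 guard on reset_interval. Without it, step % reset_interval == 0 is also true on the very first step, so the Hessian approximation would be reset before it was ever used. A tiny sketch of the corrected condition:

def should_reset(step: int, reset_interval: int | None) -> bool:
    return reset_interval is not None and step != 0 and step % reset_interval == 0

assert not should_reset(0, 10)    # no reset on the very first step anymore
assert should_reset(10, 10)       # resets every reset_interval steps after that
assert not should_reset(5, None)  # disabled when reset_interval is None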
torchzero/utils/linalg/benchmark.py
ADDED

@@ -0,0 +1,20 @@
+from collections.abc import Callable
+
+import torch
+
+
+def benchmark_solver(
+    A: torch.Tensor | Callable[[torch.Tensor], torch.Tensor],
+    b: torch.Tensor,
+    solver: Callable[[Callable[[torch.Tensor], torch.Tensor], torch.Tensor]]
+):
+    residuals = []
+    def A_mm(x):
+        if callable(A): Ax = A(x)
+        else: Ax = A@x
+        residuals.append(torch.linalg.vector_norm(Ax-b)) # pylint:disable=not-callable
+        return Ax
+
+    solver(A_mm, b)
+    return residuals
+
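A possible way to use the new benchmark_solver helper: it wraps the operator so every matrix-vector product appends ||Ax - b|| to a list, which makes it easy to inspect solver convergence. The import paths below are taken from the RECORD entries but are otherwise an assumption, and cg's signature is the one shown in the solve.py diff that follows.

import torch
from torchzero.utils.linalg.benchmark import benchmark_solver  # assumed import path
from torchzero.utils.linalg.solve import cg                    # assumed import path

A = torch.randn(16, 16)
A = A @ A.T + 16 * torch.eye(16)      # symmetric positive definite, so cg applies
b = torch.randn(16)

# benchmark_solver builds an instrumented A_mm and hands it to the solver callback
residuals = benchmark_solver(A, b, lambda A_mm, rhs: cg(A_mm, rhs, tol=1e-8))
print([round(float(r), 4) for r in residuals])  # residual norm recorded per matvec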
torchzero/utils/linalg/solve.py
CHANGED
@@ -8,27 +8,27 @@ from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_nume
 def cg(
     A_mm: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
-    x0_: torch.Tensor | None,
-    tol: float | None,
-    maxiter: int | None,
+    x0_: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
     reg: float = 0,
 ) -> torch.Tensor: ...
 @overload
 def cg(
     A_mm: Callable[[TensorList], TensorList],
     b: TensorList,
-    x0_: TensorList | None,
-    tol: float | None,
-    maxiter: int | None,
+    x0_: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
 ) -> TensorList: ...

 def cg(
     A_mm: Callable,
     b: torch.Tensor | TensorList,
-    x0_: torch.Tensor | TensorList | None,
-    tol: float | None,
-    maxiter: int | None,
+    x0_: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
 ):
     def A_mm_reg(x): # A_mm with regularization
@@ -90,7 +90,7 @@ def nystrom_sketch_and_solve(
     A_mm: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
     rank: int,
-    reg: float,
+    reg: float = 1e-3,
     generator=None,
 ) -> torch.Tensor:
     U, lambd = nystrom_approximation(
@@ -116,10 +116,10 @@ def nystrom_pcg(
     A_mm: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
     sketch_size: int,
-    reg: float,
-    x0_: torch.Tensor | None,
-    tol: float | None,
-    maxiter: int | None,
+    reg: float = 1e-6,
+    x0_: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
     generator=None,
 ) -> torch.Tensor:
     U, lambd = nystrom_approximation(
@@ -166,3 +166,4 @@ def nystrom_pcg(
         z = P_inv @ residual
         beta = residual.dot(z) / rz
         p = z + p*beta
+
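With the defaults added above, cg (and nystrom_pcg) can now be called with just an operator and a right-hand side; previously x0_, tol and maxiter had to be passed explicitly. A short sketch, assuming the module path shown in the RECORD:

import torch
from torchzero.utils.linalg.solve import cg   # assumed import path

A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
b = torch.tensor([1.0, 0.0])

x = cg(lambda v: A @ v, b)                    # x0_=None, tol=1e-4, maxiter=None by default
print(torch.allclose(A @ x, b, atol=1e-3))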
{torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.
+Version: 0.3.8
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 License: MIT License
@@ -156,7 +156,7 @@ for epoch in range(100):
   * `Newton`: Classic Newton's method.
   * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
   * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-  * `NystromPCG`: NewtonCG with Nyström preconditioning (
+  * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).

 * **Quasi-Newton**: Approximate second-order optimization methods.
   * `LBFGS`: Limited-memory BFGS.
{torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/RECORD
CHANGED

@@ -1,14 +1,14 @@
 docs/source/conf.py,sha256=jd80ZT2IdCx7nlQrpOTJL8UhGBNm6KYyXlpp0jmRiAw,1849
 tests/test_identical.py,sha256=NZ7A8Rm1U9Q16d-cG2G_wccpPtNALyoKYJt9qMownMc,11568
 tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
-tests/test_opts.py,sha256=
-tests/test_tensorlist.py,sha256=
+tests/test_opts.py,sha256=XfpDaVwOC2VuG700BXWAFWiemeVW0ucLG74yfns9mB8,40849
+tests/test_tensorlist.py,sha256=VWX9wYdfkG-0Y8I0wWPp56ZJM0mBNPvS_SC3irmcYcs,72427
 tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
 tests/test_vars.py,sha256=3p9dsHk7SJpMd-WRD0ziBNq5FEHRBJGSxbMLD8ES4J0,6815
 torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
 torchzero/core/__init__.py,sha256=2JRyeGZprTexAeEPQOIl9fLFGBwzvya-AwKyt7XAmGQ,210
 torchzero/core/module.py,sha256=Razw3c71Kfegznm0vQxsii1KuTUCPBC9UGyq2v-KX4M,27568
-torchzero/core/preconditioner.py,sha256=
+torchzero/core/preconditioner.py,sha256=R1IGk7Tbea5wSkazpnXwusjvBxzJHzEWgCtR_nEz2w4,6258
 torchzero/core/transform.py,sha256=ajNJcX45ds-_lc5CqxgLfEFGil6_BYLerB0WvoTi8rM,10303
 torchzero/modules/__init__.py,sha256=BDeyuSd2s1WFUUXIo3tGTNp4aYp4A2B94cydpPW24nY,332
 torchzero/modules/functional.py,sha256=HXNzmPe7LsPadryEm7zrcEKqGej16QDwSgBkbEvggFM,6492
@@ -16,18 +16,19 @@ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLB
 torchzero/modules/clipping/clipping.py,sha256=I-5utyrqdKtF5yaH-9m2F3UqdfpPmA2bSSFUAZ_d60Q,12544
 torchzero/modules/clipping/ema_clipping.py,sha256=pLeNuEBLpJ74io2sHn_ZVYaQ6ydEfhpVfVEX2bFttd0,5947
 torchzero/modules/clipping/growth_clipping.py,sha256=OD-kdia2Rn-DvYlYV6EZlGPDVTh9tj-W9mpiZPc3hOQ,6772
-torchzero/modules/experimental/__init__.py,sha256=
-torchzero/modules/experimental/absoap.py,sha256=
-torchzero/modules/experimental/adadam.py,sha256=
-torchzero/modules/experimental/adamY.py,sha256=
-torchzero/modules/experimental/adasoap.py,sha256=
-torchzero/modules/experimental/algebraic_newton.py,sha256=
-torchzero/modules/experimental/curveball.py,sha256=
-torchzero/modules/experimental/
-torchzero/modules/experimental/
+torchzero/modules/experimental/__init__.py,sha256=fEPDYDl7qhaFoferDRmG3ehwuqSvx4Vt2uOz0Y7h4to,483
+torchzero/modules/experimental/absoap.py,sha256=Z4MS4pDPSQ9IaTk8g57OfrsWcYVOT72x533KKtn2Zxk,13512
+torchzero/modules/experimental/adadam.py,sha256=OAPF1-NUbg79V3QOTYzsQlRC97C7XHj5boOLDqLz3PE,4029
+torchzero/modules/experimental/adamY.py,sha256=g1pAHwgdyDdKvObZ67lCSc36L99tl5jlQgOr4lMJCDo,4595
+torchzero/modules/experimental/adasoap.py,sha256=PduidICEGYICIvlYysYCFZF7-QhNX0YlhHfPhLONnUs,11247
+torchzero/modules/experimental/algebraic_newton.py,sha256=sq5ZD_j_EtlxIjNnS0rKKwTSG_JuwsZOg9ZMMQTuQm0,5154
+torchzero/modules/experimental/curveball.py,sha256=Uk30uLEztTHD5IUJLJm9Nn3x31DF9kQHmeLFhc065us,3262
+torchzero/modules/experimental/gradmin.py,sha256=iJmEvDEdVdck0C-94pY3iGxnIoNv6Fu6vj3f7lS6aQM,3686
+torchzero/modules/experimental/newton_solver.py,sha256=iGI2LHLaZd2ovpbq1Vogs76os0zWG7VwM7nUz8RzxVg,3071
 torchzero/modules/experimental/reduce_outward_lr.py,sha256=kjtRwepBGBca77ToM-lw3b8ywptMtmSdC_jQfjJAwlY,1184
-torchzero/modules/experimental/
-torchzero/modules/experimental/
+torchzero/modules/experimental/soapy.py,sha256=Ishd2Jj6BbhjrLyC48zf-cjMmA1kJb_uKXESQBIML_s,10990
+torchzero/modules/experimental/spectral.py,sha256=8_n208V2yPY3z5pCym-FvwO7DGFhozNgWlpIBtQSdrI,12139
+torchzero/modules/experimental/subspace_preconditioners.py,sha256=WnHpga7Kx4-N2xU5vP3uUHRER70ymyNJCWbSx2zXWOk,4976
 torchzero/modules/experimental/tropical_newton.py,sha256=uq66ouhgrgc8iYGozDQ3_rtbubj8rKRwb1jfcdnlpHg,4903
 torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
 torchzero/modules/grad_approximation/fdm.py,sha256=2PNNBIMup1xlOwLFAwAS3xAVd-7GGVyerMeKH1ug9LQ,3591
@@ -70,25 +71,25 @@ torchzero/modules/optimizers/orthograd.py,sha256=5BLnNJTYuGUClHmlxaXZ1jNvBR4zSFD
 torchzero/modules/optimizers/rmsprop.py,sha256=d10Y9Ck-391tVysO3xMHg3g2Pe0UEZplgebEyDYi3Z4,4333
 torchzero/modules/optimizers/rprop.py,sha256=n4k5-9F3ppH0Xl-4l4vNXfqVf2r67vMPCkstUaQKPLw,10974
 torchzero/modules/optimizers/shampoo.py,sha256=AHHV6d71DqKDPCg52ShWIPIRSGtWkMc1v1XwXgDG3qY,8606
-torchzero/modules/optimizers/soap.py,sha256=
+torchzero/modules/optimizers/soap.py,sha256=Kf2BAtIf2QY1V2ZJcUjRLcp2WfIVLd3mNclnaT3Nmds,11520
 torchzero/modules/optimizers/sophia_h.py,sha256=8pSlYVm66xWplzdP8MX3MCTzzIYHsxGzDEXJKA03Zgg,4279
 torchzero/modules/projections/__init__.py,sha256=OCxlh_-Tx-xpl31X03CeFJH9XveH563oEsWc8rUvX0A,196
 torchzero/modules/projections/dct.py,sha256=wxaEV6dTNiOqW_n2UHX0De6mMXTKDXK6UNcMNI4Rogk,2373
 torchzero/modules/projections/fft.py,sha256=OpCcEM1-A2dgk1umwRsBsvK7ObiHtsBKlkkcw0IX83Q,2961
 torchzero/modules/projections/galore.py,sha256=c9CZ0kHxpKEoyfc_lnmeHOkNp55jCppb7onN5YmWnN8,242
-torchzero/modules/projections/projection.py,sha256=
+torchzero/modules/projections/projection.py,sha256=aYufSD3ftRUqVScPmqxwEFgP1P8ioxM8z9eyzaL7d10,10147
 torchzero/modules/projections/structural.py,sha256=QaCGHmzHCXj46sM-XZ5XlYU9BnuRKI2ReR3LE8y2R4g,5740
 torchzero/modules/quasi_newton/__init__.py,sha256=0iOlX73PHj9lQS3_2cJ5lyCdas904MnFfIvR8Popvzw,402
 torchzero/modules/quasi_newton/cg.py,sha256=h-di1oKKP1tDoh-LogBRIRCp2UF9GA6XjEJPlX6xXf4,9322
-torchzero/modules/quasi_newton/lbfgs.py,sha256=
-torchzero/modules/quasi_newton/lsr1.py,sha256=
+torchzero/modules/quasi_newton/lbfgs.py,sha256=U7FKNqFTRdabf1_UUCCEO3JoDlOnWzGBhYNvVg138gg,9199
+torchzero/modules/quasi_newton/lsr1.py,sha256=BuoztcRo0lm7WW3vKtDQcfKizF-9WPItOY_X9Ng1ZS8,6033
 torchzero/modules/quasi_newton/olbfgs.py,sha256=2YAOXlMnPGw22sNcIMH1hmggzAXQRbN59RSPUZNKUZY,8352
-torchzero/modules/quasi_newton/quasi_newton.py,sha256=
+torchzero/modules/quasi_newton/quasi_newton.py,sha256=5FutzdBNpx6P8Qun9LjXd-rsy2nY2bkpQ0z0cLPnbJo,17373
 torchzero/modules/quasi_newton/experimental/__init__.py,sha256=3qpZGgdsx6wpoafWaNWx-eamRl1FuxVCWQZq8Y7Cl98,39
-torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=
-torchzero/modules/second_order/__init__.py,sha256=
+torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=ec6JKYX89xA_UlY9VrMB3hBjDyNKwkalS_4JQGA1qOY,10762
+torchzero/modules/second_order/__init__.py,sha256=jolCGaIVkID9hpxgx0Tc22wgjVlwuWekWjKTMe5jKXw,114
 torchzero/modules/second_order/newton.py,sha256=XNhscAuWwxOUwps3sUrxc2ExgkNFbilnAdszrCvQxFg,5845
-torchzero/modules/second_order/newton_cg.py,sha256=
+torchzero/modules/second_order/newton_cg.py,sha256=stVySgo7tmvntd-tuAzThzpWmZzfTnmn8ISQa5Oi4yw,2872
 torchzero/modules/second_order/nystrom.py,sha256=ZyCWrde-_-Ednj46jafuvBOzG3nC-3cPYGr-HytZbsE,6073
 torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
 torchzero/modules/smoothing/gaussian.py,sha256=YlT_G4MqAVkiWG56RHAwgt5SSPISpvQZQbSLh8mhF3I,6153
@@ -116,13 +117,14 @@ torchzero/utils/python_tools.py,sha256=RFBqNj8w52dpJ983pUPPDbg2x1MX_-SsBnBMffWGG
 torchzero/utils/tensorlist.py,sha256=qSbiliVo1euFAksdHHHRbPUdYYxfkw1dvhpXj71wGy0,53162
 torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
 torchzero/utils/linalg/__init__.py,sha256=Dzbho3_z7JDdKzYD-QdLArg0ZEoC2BVGdlE3JoAnXHQ,272
+torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
 torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
 torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
 torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
-torchzero/utils/linalg/solve.py,sha256=
+torchzero/utils/linalg/solve.py,sha256=P0PMi0zro3G3Rd0X-JeoLk7tqYDB0js0aB4bpQ0OABU,5235
 torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
-torchzero-0.3.
-torchzero-0.3.
-torchzero-0.3.
-torchzero-0.3.
-torchzero-0.3.
+torchzero-0.3.8.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
+torchzero-0.3.8.dist-info/METADATA,sha256=vj5aue0pVG8xNStpOEvPfln422K5fpV3BKF-H7ZlhRQ,13941
+torchzero-0.3.8.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
+torchzero-0.3.8.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
+torchzero-0.3.8.dist-info/RECORD,,
{torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/WHEEL
File without changes

{torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/licenses/LICENSE
File without changes

{torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/top_level.txt
File without changes