torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/__init__.py
CHANGED
torchzero/core/__init__.py
CHANGED
@@ -3,6 +3,6 @@ from .module import Chainable, Module
 from .objective import DerivativesMethod, HessianMethod, HVPMethod, Objective

 # order is important to avoid circular imports
-from .modular import Modular
+from .modular import Optimizer
 from .functional import apply, step, step_tensors, update
 from .chain import Chain, maybe_chain
torchzero/core/functional.py
CHANGED
@@ -96,7 +96,7 @@ def step_tensors(
     objective.updates = list(tensors)

     # step with modules
-    # this won't update parameters in-place because objective.Modular is None
+    # this won't update parameters in-place because objective.Optimizer is None
     objective = _chain_step(objective, modules)

     # return updates
torchzero/core/modular.py
CHANGED
@@ -15,7 +15,7 @@ from .objective import Objective
 class _EvalCounterClosure:
     """keeps track of how many times closure has been evaluated, and sets closure return"""
     __slots__ = ("modular", "closure")
-    def __init__(self, modular: "Modular", closure):
+    def __init__(self, modular: "Optimizer", closure):
         self.modular = modular
         self.closure = closure

@@ -46,9 +46,9 @@ def flatten_modules(*modules: Chainable) -> list[Module]:
     return flat


-# have to inherit from
+# have to inherit from Optimizer to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
-class Modular(torch.optim.Optimizer):
+class Optimizer(torch.optim.Optimizer):
     """Chains multiple modules into an optimizer.

     Args:
@@ -62,7 +62,7 @@ class Modular(torch.optim.Optimizer):
     param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]

     def __init__(self, params: Params | torch.nn.Module, *modules: Module):
-        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
+        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Optimizer`")
         self.model: torch.nn.Module | None = None
         """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
         if isinstance(params, torch.nn.Module):
@@ -229,5 +229,5 @@ class Modular(torch.optim.Optimizer):
         return self._closure_return

     def __repr__(self):
-        return f'Modular({", ".join(str(m) for m in self.modules)})'
+        return f'Optimizer({", ".join(str(m) for m in self.modules)})'

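The rename above is the main user-facing change in torchzero 0.4.1: the `Modular` class from 0.4.0 is now called `Optimizer` and still subclasses `torch.optim.Optimizer`. A minimal usage sketch follows, assuming the class is re-exported at the top level as `torchzero.Optimizer`, that the `tz.m` module namespace with `Adam` and `LR` modules works as in 0.4.0, and that the standard backward()/step() flow applies; these are assumptions based on this diff, not on the 0.4.1 docs.

    # hypothetical example of the renamed entry point; adjust imports to the actual re-exports
    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    opt = tz.Optimizer(model, tz.m.Adam(), tz.m.LR(1e-2))  # was tz.Modular(...) in 0.4.0

    x, y = torch.randn(16, 10), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()       # still a torch.optim.Optimizer subclass, so lr schedulers keep working
    opt.zero_grad()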
torchzero/core/module.py
CHANGED
@@ -35,7 +35,7 @@ class Module(ABC):

         # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
         # 0 - this module specific per-parameter setting overrides set via `set_param_groups` - highest priority
-        # 1 - global per-parameter setting overrides in param_groups passed to Modular - medium priority
+        # 1 - global per-parameter setting overrides in param_groups passed to Optimizer - medium priority
         # 2 - `defaults` - lowest priority
         self.settings: defaultdict[torch.Tensor, ChainMap[str, Any]] = defaultdict(lambda: ChainMap({}, {}, self.defaults))
         """per-parameter settings."""
@@ -273,7 +273,7 @@ class Module(ABC):
         return state_dict

     def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
-        """loads state_dict, ``id_to_tensor`` is passed by ``Modular``"""
+        """loads state_dict, ``id_to_tensor`` is passed by ``Optimizer``"""
         # load state
         state = state_dict['state']
         self.state.clear()
torchzero/core/objective.py
CHANGED
@@ -20,7 +20,7 @@ from ..utils.derivatives import (
 from ..utils.thoad_tools import thoad_derivatives, thoad_single_tensor, lazy_thoad

 if TYPE_CHECKING:
-    from .modular import Modular
+    from .modular import Optimizer
     from .module import Module

 def _closure_backward(closure, params, backward, retain_graph, create_graph):
@@ -135,13 +135,13 @@ class Objective:
         model (torch.nn.Module | None, optional):
             ``torch.nn.Module`` object, needed for a few modules that require access to the model. Defaults to None.
         current_step (int, optional):
-            number of times ``Modular.step()`` has been called, starting at 0. Defaults to 0.
+            number of times ``Optimizer.step()`` has been called, starting at 0. Defaults to 0.
         parent (Objective | None, optional):
             parent ``Objective`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
             Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
             e.g. when projecting. Defaults to None.
-        modular (Modular | None, optional):
-            Top-level ``Modular`` optimizer. Defaults to None.
+        modular (Optimizer | None, optional):
+            Top-level ``Optimizer`` optimizer. Defaults to None.
         storage (dict | None, optional):
             additional kwargs passed to ``step`` to control some module-specific behavior. Defaults to None.

@@ -154,7 +154,7 @@ class Objective:
         model: torch.nn.Module | None = None,
         current_step: int = 0,
         parent: "Objective | None" = None,
-        modular: "Modular | None" = None,
+        modular: "Optimizer | None" = None,
         storage: dict | None = None,
     ):
         self.params: list[torch.Tensor] = list(params)
@@ -175,8 +175,8 @@ class Objective:
         Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
         e.g. when projecting."""

-        self.modular: "Modular | None" = modular
-        """Top-level ``Modular`` optimizer, ``None`` if it wasn't specified."""
+        self.modular: "Optimizer | None" = modular
+        """Top-level ``Optimizer`` optimizer, ``None`` if it wasn't specified."""

         self.updates: list[torch.Tensor] | None = None
         """
@@ -222,7 +222,7 @@ class Objective:
         # """Storage for any other data, such as hessian estimates, etc."""

         self.attrs: dict = {}
-        """attributes, ``Modular.attrs`` is updated with this after each step.
+        """attributes, ``Optimizer.attrs`` is updated with this after each step.
         This attribute should always be modified in-place"""

         if storage is None: storage = {}
@@ -231,7 +231,7 @@ class Objective:
         This attribute should always be modified in-place"""

         self.should_terminate: bool | None = None
-        """termination criteria, ``Modular.should_terminate`` is set to this after each step if not ``None``"""
+        """termination criteria, ``Optimizer.should_terminate`` is set to this after each step if not ``None``"""

         self.temp: Any = cast(Any, None)
         """temporary storage, ``Module.update`` can set this and ``Module.apply`` access via ``objective.poptemp()``.
@@ -756,7 +756,7 @@ class Objective:
         if g_list is not None and self.grads is None:
             self.grads = list(g_list)

-        return f, g_list, H
+        return f, g_list, H.detach()

     @torch.no_grad
     def derivatives(self, order: int, at_x0: bool, method:DerivativesMethod="batched_autograd"):
torchzero/core/transform.py
CHANGED
@@ -233,7 +233,7 @@ class TensorTransform(Transform):
         if self._uses_grad: grads = objective.get_grads()
         else: grads = None # better explicitly set to None rather than objective.grads because it shouldn't be used

-        if self._uses_loss: loss = objective.get_loss(backward=
+        if self._uses_loss: loss = objective.get_loss(backward=True)
         else: loss = None

         return grads, loss
torchzero/linalg/__init__.py
CHANGED
@@ -3,8 +3,9 @@ from . import linear_operator
 from .matrix_power import (
     matrix_power_eigh,
     matrix_power_svd,
+    MatrixPowerMethod,
 )
-from .orthogonalize import zeropower_via_eigh, zeropower_via_newtonschulz5, zeropower_via_svd, orthogonalize
+from .orthogonalize import zeropower_via_eigh, zeropower_via_newtonschulz5, zeropower_via_svd, orthogonalize,OrthogonalizeMethod
 from .qr import qr_householder
 from .solve import cg, nystrom_sketch_and_solve, nystrom_pcg
-from .eigh import nystrom_approximation
+from .eigh import nystrom_approximation, regularize_eigh
torchzero/linalg/eigh.py
CHANGED
@@ -1,7 +1,11 @@
 from collections.abc import Callable
+
 import torch
-from .linalg_utils import mm

+from . import torch_linalg
+from .linalg_utils import mm
+from .orthogonalize import OrthogonalizeMethod, orthogonalize
+from .svd import tall_reduced_svd_via_eigh


 # https://arxiv.org/pdf/2110.02820
@@ -11,6 +15,8 @@ def nystrom_approximation(
     ndim: int,
     rank: int,
     device,
+    orthogonalize_method: OrthogonalizeMethod = 'qr',
+    eigv_tol: float = 0,
     dtype = torch.float32,
     generator = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -20,7 +26,7 @@ def nystrom_approximation(
     A is ``(m,m)``, then Q is ``(m, rank)``; L is a ``(rank, )`` vector - diagonal of ``(rank, rank)``"""
     # basis
     O = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
-    O
+    O = orthogonalize(O, method=orthogonalize_method) # Thin QR decomposition # pylint:disable=not-callable

     # Y = AΩ
     AO = mm(A_mv=A_mv, A_mm=A_mm, X=O)
@@ -29,6 +35,219 @@ def nystrom_approximation(
     Yv = AO + v*O # Shift for stability
     C = torch.linalg.cholesky_ex(O.mT @ Yv)[0] # pylint:disable=not-callable
     B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
+
+    # Q, S, _ = torch_linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
+    # B is (ndim, rank) so we can use eigendecomp of (rank, rank)
+    Q, S = tall_reduced_svd_via_eigh(B, tol=eigv_tol, retry_float64=True)
+
+    L = S.pow(2) - v
+    return L, Q
+
+
+def regularize_eigh(
+    L: torch.Tensor,
+    Q: torch.Tensor,
+    truncate: int | None = None,
+    tol: float | None = None,
+    damping: float = 0,
+    rdamping: float = 0,
+) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
+    """Applies regularization to eigendecomposition. Returns ``(L, Q)``.
+
+    Args:
+        L (torch.Tensor): eigenvalues, shape ``(rank,)``.
+        Q (torch.Tensor): eigenvectors, shape ``(n, rank)``.
+        truncate (int | None, optional):
+            keeps top ``truncate`` eigenvalues. Defaults to None.
+        tol (float | None, optional):
+            all eigenvalues smaller than largest eigenvalue times ``tol`` are removed. Defaults to None.
+        damping (float | None, optional): scalar added to eigenvalues. Defaults to 0.
+        rdamping (float | None, optional): scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.
+    """
+    # remove non-finite eigenvalues
+    finite = L.isfinite()
+    if finite.any():
+        L = L[finite]
+        Q = Q[:, finite]
+    else:
+        return None, None
+
+    # largest finite!!! eigval
+    L_max = L[-1] # L is sorted in ascending order
+
+    # remove small eigenvalues relative to largest
+    if tol is not None:
+        indices = L > tol * L_max
+        L = L[indices]
+        Q = Q[:, indices]
+
+    # truncate to rank (L is ordered in ascending order)
+    if truncate is not None:
+        L = L[-truncate:]
+        Q = Q[:, -truncate:]
+
+    # damping
+    d = damping + rdamping * L_max
+    if d != 0:
+        L += d
+
     return L, Q
+
+def eigh_plus_uuT(
+    L: torch.Tensor,
+    Q: torch.Tensor,
+    u: torch.Tensor,
+    alpha: float = 1,
+    tol: float | None = None,
+    retry_float64: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    compute eigendecomposition of Q L Q^T + alpha * (u u^T) where Q is ``(m, rank)`` and L is ``(rank, )`` and u is ``(m, )``
+    """
+    if tol is None: tol = torch.finfo(Q.dtype).eps
+    z = Q.T @ u # (rank,)
+
+    # component of u orthogonal to the column space of Q
+    res = u - Q @ z # (m,)
+    beta = torch.linalg.vector_norm(res) # pylint:disable=not-callable
+
+    if beta < tol:
+        # u is already in the column space of Q
+        B = L.diag_embed().add_(z.outer(z), alpha=alpha) # (rank, rank)
+        L_prime, S = torch_linalg.eigh(B, retry_float64=retry_float64)
+        Q_prime = Q @ S
+        return L_prime, Q_prime
+
+    # normalize the orthogonal component to get a new orthonormal vector
+    v = res / beta # (m, )
+
+    # project and compute new eigendecomposition
+    D_diag = torch.cat([L, torch.tensor([0.0], device=Q.device, dtype=Q.dtype)])
+    w = torch.cat([z, beta.unsqueeze(0)]) # Shape: (rank+1,)
+    B = D_diag.diag_embed().add_(w.outer(w), alpha=alpha)
+
+    L_prime, S = torch_linalg.eigh(B, retry_float64=retry_float64)
+
+    # unproject and sort
+    basis = torch.cat([Q, v.unsqueeze(-1)], dim=1) # (m, rank+1)
+    Q_prime = basis @ S # (m, rank+1)
+
+    idx = torch.argsort(L_prime)
+    L_prime = L_prime[idx]
+    Q_prime = Q_prime[:, idx]
+
+    return L_prime, Q_prime
+
+def eigh_plus_UUT(
+    L: torch.Tensor,
+    Q: torch.Tensor,
+    U: torch.Tensor,
+    alpha: float = 1,
+    tol = None,
+    retry_float64: bool = False,
+):
+    """
+    compute eigendecomposition of Q L Q^T + alpha * (U U^T), where Q is ``(m, rank)`` and L is ``(rank, )``,
+    U is ``(m, k)`` where k is rank of correction
+    """
+    if U.size(1) == 1:
+        return eigh_plus_uuT(L, Q, U[:,0], alpha=alpha, tol=tol, retry_float64=retry_float64)
+
+    if tol is None: tol = torch.finfo(Q.dtype).eps
+    m, r = Q.shape
+
+    Z = Q.T @ U # (r, k)
+    U_res = U - Q @ Z # (m, k)
+
+    # find cols of U not in col space of Q
+    res_norms = torch.linalg.vector_norm(U_res, dim=0) # pylint:disable=not-callable
+    new_indices = torch.where(res_norms > tol)[0]
+    k_prime = len(new_indices)
+
+    if k_prime == 0:
+        # all cols are in Q
+        B = Q
+        C = Z # (r x k)
+        r_new = r
+    else:
+        # orthonormalize directions that aren't in Q
+        U_new = U_res[:, new_indices]
+        Q_u, _ = torch_linalg.qr(U_new, mode='reduced', retry_float64=retry_float64)
+        B = torch.hstack([Q, Q_u])
+        C = torch.vstack([Z, Q_u.T @ U])
+        r_new = r + k_prime


+    # project and compute new eigendecomposition
+    A_proj = torch.zeros((r_new, r_new), device=Q.device, dtype=Q.dtype)
+    A_proj[:r, :r] = L.diag_embed()
+    A_proj.addmm_(C, C.T, alpha=alpha)
+
+    L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+
+    # unproject and sort
+    Q_prime = B @ S
+    idx = torch.argsort(L_prime)
+    L_prime = L_prime[idx]
+    Q_prime = Q_prime[:, idx]
+
+    return L_prime, Q_prime
+
+
+def eigh_plus_UVT_symmetrize(
+    Q: torch.Tensor,
+    L: torch.Tensor,
+    U: torch.Tensor,
+    V: torch.Tensor,
+    alpha: float,
+    retry_float64: bool = False,
+
+):
+    """
+    Q is ``(m, rank)``; L is ``(rank, )``; U and V are the low rank correction such that U V^T is ``(m, m)``.
+
+    This computes eigendecomposition of A, where
+
+    ``M = Q diag(L) Q^T + alpha * (U V^T)``;
+
+    ``A = (M + M^T) / 2``
+    """
+    m, rank = Q.shape
+    _, k = V.shape
+
+    # project U and V out of the Q subspace via Gram-schmidt
+    Q_T_U = Q.T @ U
+    U_perp = U - Q @ Q_T_U
+
+    Q_T_V = Q.T @ V
+    V_perp = V - Q @ Q_T_V
+
+    R = torch.hstack([U_perp, V_perp])
+    Q_perp, _ = torch_linalg.qr(R, retry_float64=retry_float64)
+
+    Q_B = torch.hstack([Q, Q_perp])
+    r_B = Q_B.shape[1]
+
+    # project, symmetrize and compute new eigendecomposition
+    A_proj = torch.zeros((r_B, r_B), device=Q.device, dtype=Q.dtype)
+    A_proj[:rank, :rank] = L.diag_embed()
+
+    Q_perp_T_U = Q_perp.T @ U
+    Q_B_T_U = torch.vstack([Q_T_U, Q_perp_T_U])
+
+    Q_perp_T_V = Q_perp.T @ V
+    Q_B_T_V = torch.vstack([Q_T_V, Q_perp_T_V])
+
+    update_proj = Q_B_T_U @ Q_B_T_V.T + Q_B_T_V @ Q_B_T_U.T
+    A_proj.add_(update_proj, alpha=alpha/2)
+
+    L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+
+    # unproject and sort
+    Q_prime = Q_B @ S
+
+    idx = torch.argsort(L_prime)
+    L_prime = L_prime[idx]
+    Q_prime = Q_prime[:, idx]
+
+    return L_prime, Q_prime
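The new helpers above operate on an existing eigendecomposition ``(L, Q)`` with eigenvalues in ascending order, which is what ``torch.linalg.eigh`` returns. A small sanity-check sketch follows; the import path, the 6x6 test matrix, and the thresholds are illustrative assumptions, not taken from the package docs.

    # hypothetical check of eigh_plus_uuT and regularize_eigh against dense reconstructions
    import torch
    from torchzero.linalg.eigh import eigh_plus_uuT, regularize_eigh  # path assumed from this diff

    torch.manual_seed(0)
    A = torch.randn(6, 6, dtype=torch.float64)
    A = A @ A.T                      # symmetric PSD test matrix
    L, Q = torch.linalg.eigh(A)      # ascending eigenvalues, full-rank Q

    # rank-one update: eigendecomposition of A + 0.5 * u u^T without rebuilding the dense matrix
    u = torch.randn(6, dtype=torch.float64)
    L2, Q2 = eigh_plus_uuT(L, Q, u, alpha=0.5)
    print(torch.allclose(Q2 @ torch.diag(L2) @ Q2.T, A + 0.5 * torch.outer(u, u), atol=1e-8))

    # keep the top 3 eigenpairs, drop tiny ones, and add a small diagonal shift
    L_r, Q_r = regularize_eigh(L, Q, truncate=3, tol=1e-12, damping=1e-8)
    if L_r is not None:              # (None, None) means no finite eigenvalues survived
        A_low_rank = Q_r @ torch.diag(L_r) @ Q_r.T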
torchzero/linalg/orthogonalize.py
CHANGED
@@ -1,4 +1,5 @@
 from typing import Literal
+
 import torch

 from ..utils.compile import allow_compile
@@ -49,9 +50,6 @@ def zeropower_via_newtonschulz5(G: torch.Tensor, coeffs=_NS_COEFFS) -> torch.Ten

     return X.to(G.dtype)

-# code from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
-# Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
-# Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
 def zeropower_via_svd(A: torch.Tensor) -> torch.Tensor:
     """
     Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
@@ -87,7 +85,7 @@ def orthogonalize_via_qr(A: torch.Tensor):
     return Q

 OrthogonalizeMethod = Literal["newtonschulz", "svd", "qr"]
-def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod
+def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod) -> torch.Tensor:
     if method == "newtonschulz": return zeropower_via_newtonschulz5(A)
     if method == "svd": return zeropower_via_svd(A)
     if method == "qr": return orthogonalize_via_qr(A)
torchzero/linalg/qr.py
CHANGED
@@ -2,6 +2,18 @@ from typing import Literal
 import torch
 from ..utils.compile import allow_compile

+
+# super slow
+# def cholesky_qr(A):
+#     """QR of (m, n) A via cholesky of (n, n) matrix"""
+#     AtA = A.T @ A
+
+#     L, _ = torch.linalg.cholesky_ex(AtA) # pylint:disable=not-callable
+#     R = L.T
+
+#     Q = torch.linalg.solve_triangular(R.T, A.T, upper=False).T # pylint:disable=not-callable
+#     return Q, R
+
 # reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
 @allow_compile
 def _get_w_tau(R: torch.Tensor, i: int, eps: float):
torchzero/linalg/solve.py
CHANGED
@@ -25,15 +25,13 @@ def _make_A_mv_reg(A_mv: Callable, reg):

 def _identity(x): return x

-# TODO this is used in NystromSketchAndSolve
-# I need to add alternative to it where it just shifts eigenvalues by reg and uses their reciprocal
 def nystrom_sketch_and_solve(
     L: torch.Tensor,
     Q: torch.Tensor,
     b: torch.Tensor,
     reg: float = 1e-3,
 ) -> torch.Tensor:
-    """Solves (Q diag(L) Q.T + reg*I)x = b
+    """Solves ``(Q diag(L) Q.T + reg*I)x = b``. Becomes super unstable with reg smaller than like 1e-5.

     Args:
         L (torch.Tensor): eigenvalues, like from ``nystrom_approximation``
torchzero/linalg/svd.py
CHANGED
@@ -1,20 +1,47 @@
+import torch
+
+from . import torch_linalg
+
+
+def tall_reduced_svd_via_eigh(A: torch.Tensor, tol: float = 0, retry_float64:bool=False):
+    """
+    Given a tall matrix A of size (m, n), computes U and S from the reduced SVD(A)
+    using the eigendecomposition of (n, n) matrix which is faster than direct SVD when m >= n.
+
+    This truncates small singular values that would causes nans,
+    so the returned U and S can have reduced dimension ``k <= n``.
+
+    Returns U of size ``(m, k)`` and S of size ``(k, )``.
+
+    Args:
+        A (torch.Tensor): A tall matrix of size (m, n) with m >= n.
+        tol (float): Tolerance for truncating small singular values. Singular values
+            less than ``tol * max_singular_value`` will be discarded.


+    """
+    # if m < n, A.T A will be low rank and we can't use eigh
+    m, n = A.size()
+    if m < n:
+        U, S, V = torch_linalg.svd(A, full_matrices=False, retry_float64=retry_float64)
+        return U, S
+
+    M = A.mH @ A # n,n
+
+    try:
+        L, Q = torch_linalg.eigh(M, retry_float64=retry_float64)
+    except torch.linalg.LinAlgError:
+        U, S, V = torch_linalg.svd(A, full_matrices=False, retry_float64=retry_float64)
+        return U, S
+
+    L = torch.flip(L, dims=[-1])
+    Q = torch.flip(Q, dims=[-1])
+
+    indices = L > tol * L[0] # L[0] is the max eigenvalue
+    L = L[indices]
+    Q = Q[:, indices]
+
+    S = L.sqrt()
+    U = (A @ Q) / S
+
+    return U, S
torchzero/modules/__init__.py
CHANGED
@@ -1,4 +1,6 @@
 from . import experimental
+from .adaptive import *
+from .adaptive import lre_optimizers as lre
 from .clipping import *
 from .conjugate_gradient import *
 from .grad_approximation import *
@@ -7,9 +9,9 @@ from .line_search import *
 from .misc import *
 from .momentum import *
 from .ops import *
-from .adaptive import *
 from .projections import *
 from .quasi_newton import *
+from .restarts import *
 from .second_order import *
 from .smoothing import *
 from .step_size import *
@@ -18,5 +20,4 @@ from .trust_region import *
 from .variance_reduction import *
 from .weight_decay import *
 from .wrappers import *
-from .
-from .zeroth_order import *
+from .zeroth_order import *
torchzero/modules/adaptive/__init__.py
CHANGED
@@ -1,4 +1,5 @@
-from .
+from . import lre_optimizers
+from .adagrad import Adagrad, AdagradNorm, FullMatrixAdagrad

 # from .curveball import CurveBall
 # from .spectral import SpectralPreconditioner
@@ -8,14 +9,21 @@ from .adan import Adan
 from .adaptive_heavyball import AdaptiveHeavyBall
 from .aegd import AEGD
 from .esgd import ESGD
-from .lmadagrad import LMAdagrad
 from .lion import Lion
+from .ggt import GGT
 from .mars import MARSCorrection
 from .matrix_momentum import MatrixMomentum
-from .msam import
+from .msam import MSAM, MSAMMomentum
 from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
 from .natural_gradient import NaturalGradient
 from .orthograd import OrthoGrad, orthograd_
+from .psgd import (
+    PSGDDenseNewton,
+    PSGDKronNewton,
+    PSGDKronWhiten,
+    PSGDLRANewton,
+    PSGDLRAWhiten,
+)
 from .rmsprop import RMSprop
 from .rprop import (
     BacktrackOnSignChange,