torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0

torchzero/modules/conjugate_gradient/cg.py

@@ -3,21 +3,14 @@ from typing import Literal
 
 import torch
 
-from ...core import
-
-
-    Module,
-    Transform,
-    Var,
-    apply_transform,
-)
-from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
-from ..line_search import LineSearchBase
+from ...core import Chainable, TensorTransform
+
+from ...utils import TensorList, safe_dict_update_, unpack_dicts, unpack_states
 from ..quasi_newton.quasi_newton import HessianUpdateStrategy
-from ..
+from ..opt_utils import safe_clip
 
 
-class ConguateGradientBase(
+class ConguateGradientBase(TensorTransform, ABC):
     """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
 
     This is an abstract class, to use it, subclass it and override `get_beta`.
@@ -52,13 +45,8 @@ class ConguateGradientBase(Transform, ABC):
     """
     def __init__(self, defaults, clip_beta: bool, restart_interval: int | None | Literal['auto'], inner: Chainable | None = None):
         if defaults is None: defaults = {}
-        defaults
-        defaults
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
+        safe_dict_update_(defaults, dict(restart_interval=restart_interval, clip_beta=clip_beta))
+        super().__init__(defaults, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
@@ -74,40 +62,38 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""
 
     @torch.no_grad
-    def
-        tensors =
-        params =
-
-        step = self.global_state.get('step', 0) + 1
-        self.global_state['step'] = step
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        params = TensorList(params)
+        self.increment_counter("step", start=0)
 
         # initialize on first step
-        if self.global_state.get('stage',
+        if self.global_state.get('stage', "first update") == "first update":
             g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
             d_prev.copy_(tensors)
             g_prev.copy_(tensors)
             self.initialize(params, tensors)
-            self.global_state['stage'] =
+            self.global_state['stage'] = "first apply"
 
         else:
             # if `update_tensors` was called multiple times before `apply_tensors`,
             # stage becomes 2
-            self.global_state['stage'] =
+            self.global_state['stage'] = "initialized"
 
     @torch.no_grad
-    def
-        tensors =
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
         step = self.global_state['step']
 
-
-        tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+        assert self.global_state['stage'] != "first update"
 
-
-
-
+        # on 1st apply we don't have previous gradients
+        # so just return tensors
+        if self.global_state['stage'] == "first apply":
+            self.global_state['stage'] = "initialized"
             return tensors
 
-        params =
+        params = TensorList(params)
         g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
 
         # get beta
@@ -119,10 +105,13 @@ class ConguateGradientBase(Transform, ABC):
         dir = tensors.add_(d_prev.mul_(beta))
         d_prev.copy_(dir)
 
-        # resetting
+        # resetting every `reset_interval` steps, use step+1 to not reset on 1st step
+        # so if reset_interval=2, then 1st step collects g_prev and d_prev, then
+        # two steps will happen until reset.
         restart_interval = settings[0]['restart_interval']
         if restart_interval == 'auto': restart_interval = tensors.global_numel() + 1
-
+
+        if restart_interval is not None and (step + 1) % restart_interval == 0:
             self.state.clear()
             self.global_state.clear()
 
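
For orientation, here is a minimal, self-contained sketch of the recurrence these hunks implement, stripped of torchzero's TensorTransform machinery: the new direction is `d = g + beta * d_prev` (``dir = tensors.add_(d_prev.mul_(beta))`` in the diff), with a periodic restart every `restart_interval` steps. The Fletcher–Reeves formula stands in for the abstract `get_beta`, and a fixed step size replaces a line search; both are illustrative assumptions, not torchzero's API.

```python
# Minimal sketch of the conjugate-gradient direction recurrence, assuming
# Fletcher-Reeves beta as one possible `get_beta` implementation.
import torch

def cg_direction(g: torch.Tensor, g_prev: torch.Tensor, d_prev: torch.Tensor,
                 step: int, restart_interval: int) -> torch.Tensor:
    """Returns d = g + beta * d_prev, restarting periodically."""
    if step == 0 or step % restart_interval == 0:
        return g.clone()                      # restart: fall back to the raw gradient as the update
    beta = g.dot(g) / g_prev.dot(g_prev)      # Fletcher-Reeves beta (assumed choice)
    beta = beta.clamp(min=0)                  # corresponds to clip_beta=True
    return g + beta * d_prev

# usage on a toy quadratic f(x) = 0.5 * x @ A @ x
A = torch.diag(torch.tensor([1.0, 10.0, 100.0]))
x = torch.ones(3)
g_prev = d_prev = torch.zeros(3)
for step in range(20):
    g = A @ x                                 # gradient of the quadratic
    d = cg_direction(g, g_prev, d_prev, step, restart_interval=x.numel() + 1)
    x = x - 0.01 * d                          # fixed step size instead of a line search
    g_prev, d_prev = g, d
```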

torchzero/modules/experimental/__init__.py

@@ -1,19 +1,20 @@
 """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
+from .adanystrom import AdaNystrom
+from .common_directions_whiten import CommonDirectionsWhiten
+from .coordinate_momentum import CoordinateMomentum
+from .cubic_adam import CubicAdam, SubspaceCubicAdam
 from .curveball import CurveBall
+from .eigen_sr1 import EigenSR1
 
 # from dct import DCTProjection
+from .eigengrad import Eigengrad
 from .fft import FFTProjection
 from .gradmin import GradMin
 from .higher_order_newton import HigherOrderNewton
 from .l_infinity import InfinityNormTrustRegion
-from .momentum import (
-    CoordinateMomentum,
-    NesterovEMASquared,
-    PrecenteredEMASquared,
-    SqrtNesterovEMASquared,
-)
 from .newton_solver import NewtonSolver
 from .newtonnewton import NewtonNewton
 from .reduce_outward_lr import ReduceOutwardLR
 from .scipy_newton_cg import ScipyNewtonCG
+from .spsa1 import SPSA1
 from .structural_projections import BlockPartition, TensorizeProjection

torchzero/modules/experimental/adanystrom.py

@@ -0,0 +1,258 @@
+# pylint: disable = non-ascii-name
+import torch
+
+from ...core import Chainable, TensorTransform
+from ...linalg import (
+    OrthogonalizeMethod,
+    orthogonalize,
+    regularize_eigh,
+    torch_linalg,
+)
+from ...linalg.linear_operator import Eigendecomposition
+from ..adaptive.lre_optimizers import LREOptimizerBase
+from .eigengrad import _eigengrad_update_state_, eigengrad_apply
+
+
+def weighted_eigen_plus_rank1_mm(
+    # A1 = Q1 @ diag(L1) @ Q1.T
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+
+    # K2 = v2 @ v2.T
+    v2: torch.Tensor,
+
+    # second matrix
+    B: torch.Tensor,
+
+    # weights
+    w1: float,
+    w2: float,
+
+) -> torch.Tensor:
+    """
+    Computes ``(w1 * A1 + w2 * A2) @ B``, where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
+
+    Returns ``(n, k)``
+
+    Args:
+        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
+        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
+        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)``.
+        B (torch.Tensor): shape ``(n, k)``.
+        w1 (float): weight for A1.
+        w2 (float): weight for A2.
+
+    """
+    # sketch A1
+    QTB = Q1.T @ B # (rank, k)
+    LQTB = L1.unsqueeze(1) * QTB # (rank, k)
+    sketch1 = Q1 @ LQTB # (n, k)
+
+    # skecth A2
+    vB = v2 @ B
+    sketch2 = v2.outer(vB)
+
+    return w1 * sketch1 + w2 * sketch2
+
+
+def adanystrom_update(
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+    v2: torch.Tensor,
+    w1: float,
+    w2: float,
+    oversampling_p: int,
+    rank: int,
+    eig_tol: float,
+    damping: float,
+    rdamping: float,
+    orthogonalize_method: OrthogonalizeMethod,
+
+) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    """computes the Nyström approximation of ``(w1 * A1 + w2 * A2)``,
+    where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
+
+    returns L of shape ``(k, )`` and Q of shape ``(n, k)``.
+
+    Args:
+        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
+        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
+        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)`` or ``(n, 1)``.
+        w1 (float): weight for A1.
+        w2 (float): weight for A2.
+    """
+    n = Q1.shape[0]
+    device = Q1.device
+    dtype = Q1.dtype
+    l = rank + oversampling_p
+
+    # gaussian test matrix
+    Omega = torch.randn(n, l, device=device, dtype=dtype)
+
+    # sketch
+    AOmega = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Omega, w1, w2)
+    Q = orthogonalize(AOmega, orthogonalize_method)
+
+    AQ = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Q, w1, w2)
+    QTAQ = Q.T @ AQ
+
+    W = (QTAQ + QTAQ.T) / 2.0
+
+    # compute new L and Q
+    try:
+        L_prime, S = torch_linalg.eigh(W, retry_float64=True)
+    except torch.linalg.LinAlgError:
+        return L1, Q1
+
+    L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)
+
+    if L_prime is None or S is None:
+        return L1, Q1
+
+    return L_prime, Q @ S
+
+
+# def adanystrom_update2(
+#     L1: torch.Tensor,
+#     Q1: torch.Tensor,
+#     v2: torch.Tensor,
+#     w1: float,
+#     w2: float,
+#     rank: int,
+# ):
+#     def A_mm(X):
+#         return weighted_eigen_plus_rank1_mm(L1=L1, Q1=Q1, v2=v2, B=X, w1=w1, w2=w2)
+
+#     return nystrom_approximation(A_mm, A_mm=A_mm, ndim=v2.numel(), rank=rank, device=L1.device, dtype=L1.dtype)
+
+class AdaNystrom(TensorTransform):
+    """Adagrad/RMSprop/Adam with Nyström-approximated covariance matrix.
+
+    Args:
+        rank (_type_): rank of Nyström approximation.
+        w1 (float, optional): weight of current covariance matrix. Defaults to 0.95.
+        w2 (float, optional): weight of new gradient in covariance matrix. Defaults to 0.05.
+        oversampling (int, optional): number of extra random vectors (top rank eigenvalues are kept). Defaults to 10.
+        eig_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when updating the preconditioner. Defaults to 1e-7.
+        damping (float, optional):
+            added to eigenvalues when updating the preconditioner. Defaults to 1e-8.
+        rdamping (float, optional):
+            added to eigenvalues when updating the preconditioner, relative to largest eigenvalue. Defaults to 0.
+        mm_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when computing the update. Defaults to 1e-7.
+        mm_truncate (int | None, optional):
+            uses top k eigenvalues to compute the update. Defaults to None.
+        mm_damping (float, optional):
+            added to eigenvalues when computing the update. Defaults to 1e-4.
+        mm_rdamping (float, optional):
+            added to eigenvalues when computing the update, relative to largest eigenvalue. Defaults to 0.
+        id_reg (float, optional):
+            multiplier to identity matrix added to preconditioner before computing update
+            If this value is given, solution from Nyström sketch-and-solve will be used to compute the update.
+            This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
+        concat_params (bool, optional):
+            whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
+        update_freq (int, optional): update frequency. Defaults to 1.
+        inner (Chainable | None, optional): inner modules. Defaults to None.
+    """
+    def __init__(
+        self,
+        rank:int = 100,
+        beta=0.95,
+        oversampling: int = 10,
+        eig_tol: float | None = 1e-32,
+        damping: float = 0,
+        rdamping: float = 0,
+        mm_tol: float = 0,
+        mm_truncate: int | None = None,
+        mm_damping: float = 0,
+        mm_rdamping: float = 0,
+        id_reg: float | None = None,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
+        eigenbasis_optimizer: LREOptimizerBase | None = None,
+        orthogonalize_interval: int | None = 100,
+
+        concat_params: bool = True,
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        for k in ["self", "concat_params", "inner", "update_freq"]:
+            del defaults[k]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
+
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        state["step"] = state.get("step", 0) + 1
+        rank = setting["rank"]
+        device = tensor.device
+        dtype = tensor.dtype
+        beta = setting["beta"]
+
+        try:
+            if "L" not in state:
+                # use just tensor and zero L and Q with zero weight
+
+                L, Q = adanystrom_update(
+                    L1=torch.zeros(rank, device=device, dtype=dtype),
+                    Q1=torch.zeros((tensor.numel(), rank), device=device, dtype=dtype),
+                    v2=tensor.ravel(),
+                    w1=0,
+                    w2=1-beta,
+                    rank=rank,
+                    oversampling_p=setting["oversampling"],
+                    eig_tol=setting["eig_tol"],
+                    damping=setting["damping"],
+                    rdamping=setting["rdamping"],
+                    orthogonalize_method=setting["orthogonalize_method"],
+                )
+
+                state["L"] = state["L_reg"] = L
+                state["Q"] = state["Q_reg"] = Q
+
+            else:
+                L = state["L"]
+                Q = state["Q"]
+
+                w1 = beta
+                w2 = 1 - w1
+
+                # compute new factors (this function truncates them)
+                L_new, Q_new = adanystrom_update(
+                    L1=L,
+                    Q1=Q,
+                    v2=tensor.ravel(),
+                    w1=w1,
+                    w2=w2,
+                    rank=rank,
+                    oversampling_p=setting["oversampling"],
+                    eig_tol=setting["eig_tol"],
+                    damping=setting["damping"],
+                    rdamping=setting["rdamping"],
+                    orthogonalize_method=setting["orthogonalize_method"],
+                )
+
+                _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
+
+        except torch.linalg.LinAlgError:
+            pass
+
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        if "L_reg" not in state:
+            return tensor.clip(-0.1, 0.1)
+
+        if "eigenbasis_state" not in state:
+            state["eigenbasis_state"] = {}
+
+        return eigengrad_apply(
+            tensor=tensor,
+            L_reg = state["L_reg"],
+            Q_reg = state["Q_reg"],
+            beta = setting["beta"],
+            step = state["step"],
+            debias = True,
+            id_reg = setting["id_reg"],
+            eigenbasis_optimizer = setting["eigenbasis_optimizer"],
+            eigenbasis_state = state["eigenbasis_state"]
+        )

torchzero/modules/experimental/common_directions_whiten.py

@@ -0,0 +1,142 @@
+from collections import deque
+from typing import Literal
+
+import torch
+
+from torchzero.core import Chainable, TensorTransform
+from torchzero.linalg import matrix_power_eigh, torch_linalg, orthogonalize, OrthogonalizeMethod, regularize_eigh
+from torchzero.utils import TensorList, vec_to_tensors_
+
+
+def update_subspace_preconditioner_(
+    grad: torch.Tensor, # store grads and basis as vectors for matmul
+    basis: torch.Tensor, # ndim, k
+    accumulator_: torch.Tensor, # k, k
+    beta: float | None,
+):
+    projected = basis.T @ grad # k
+    outer = torch.outer(projected, projected)
+
+    if beta is None: accumulator_.add_(outer)
+    else: accumulator_.lerp_(outer, 1-beta)
+
+# yeah so I can also run subspace opts in this basis
+def apply_subspace_preconditioner(
+    tensor: torch.Tensor,
+    basis: torch.Tensor, # ndim, k
+    accumulator: torch.Tensor,
+    tol: float,
+    truncate: int | None,
+    damping: float,
+    rdamping: float,
+):
+    L, Q = torch_linalg.eigh(accumulator, retry_float64=True)
+    L, Q = regularize_eigh(L=L, Q=Q, truncate=truncate, tol=tol, damping=damping, rdamping=rdamping)
+
+    if L is None or Q is None:
+        return tensor.clip(-0.1, 0.1)
+
+    preconditioner = (Q * L.rsqrt().unsqueeze(-2)) @ Q.mH
+
+    tensor_projected = basis.T @ tensor # k
+    update_projected = preconditioner @ tensor_projected # k
+    return basis @ update_projected # d
+
+
+class CommonDirectionsWhiten(TensorTransform):
+    """Whitens in subspace spanned by history of gradient differences.
+
+    Args:
+        beta - for preconditioner itself in the basis.
+        basis_beta - how much basis is allowed to change.
+    """
+
+    def __init__(
+        self,
+        k: int = 100,
+        beta: float | None = 0.95,
+        basis_beta=0.95,
+        tol: float = 1e-7,
+        truncate: int | None = None,
+        damping: float = 1e-4,
+        rdamping: float = 0,
+        basis_type: Literal["gradients", "differences"] = "differences",
+        orthogonalize_method: OrthogonalizeMethod | None = 'newtonschulz',
+
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        for key in ["self", "inner", "concat_params"]:
+            del defaults[key]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        g = tensor.ravel()
+        k = setting['k']
+        beta = setting['beta']
+        basis_beta = setting['basis_beta']
+        step = state.get("step", 0)
+        state["step"] = step + 1
+
+        # initialize history
+        if 'history' not in state:
+            state['history'] = deque(maxlen=k)
+            state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
+            state['basis'] = torch.zeros(g.numel(), k, device=g.device, dtype=g.dtype)
+
+        history: deque = state['history']
+        accumulator = state['accumulator']
+        basis = state['basis']
+        history.append(g)
+
+        # stack history to new basis term, if history isn't full, fill with random vecs
+        if len(history) < k:
+            basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
+            history_basis = torch.stack(tuple(history), -1)
+            basis_t[:, -len(history):] = history_basis
+
+        else:
+            basis_t = torch.stack(tuple(history), -1)
+
+        # in this case basis uses differences in gradients except last entry is the gradient
+        if setting["basis_type"] == "differences":
+            basis_t[:,:-1] = basis_t[:, :-1] - basis_t[:, 1:]
+
+        # normalize or orthonormalize new basis term
+        if setting["orthogonalize_method"] is not None:
+            basis_t = orthogonalize(basis_t, method = setting["orthogonalize_method"])
+        else:
+            basis_t = (basis_t - basis_t.mean()) / basis_t.std().clip(min=torch.finfo(g.dtype).tiny * 2)
+
+        # lerp basis
+        basis.lerp_(basis_t, 1-basis_beta)
+        basis = basis / (1 - basis_beta ** (step+1)) # correct bias on basis EMA
+        update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        g = tensor.ravel()
+
+        basis = state['basis']
+        accumulator = state['accumulator']
+        step = state["step"]
+        accumulator = accumulator / (1 - setting["beta"] ** (step+1)) # correct bias on accumulator EMA
+
+        try:
+            preconditioned = apply_subspace_preconditioner(
+                g,
+                basis,
+                accumulator,
+                tol=setting["tol"],
+                truncate=setting["truncate"],
+                damping=setting["damping"],
+                rdamping=setting["rdamping"],
+            )
+        except torch.linalg.LinAlgError:
+            preconditioned = g.clip(-0.1, 0.1)
+
+        return preconditioned.view_as(tensor)
+
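
The core of ``CommonDirectionsWhiten`` is ordinary subspace whitening: project the gradient onto a small basis, track an EMA of projected outer products, and apply the inverse square root of that accumulator before mapping back. Below is a standalone sketch of just that step, using plain ``torch.linalg.eigh`` with a simple damping term in place of torchzero's ``torch_linalg.eigh``/``regularize_eigh`` helpers, so the regularization details are illustrative assumptions.

```python
# Subspace whitening sketch: update = basis @ C^(-1/2) @ (basis.T @ g).
import torch

torch.manual_seed(0)
ndim, k, beta, damping = 1000, 16, 0.95, 1e-4

basis, _ = torch.linalg.qr(torch.randn(ndim, k))   # fixed orthonormal subspace basis
C = torch.eye(k)                                   # k x k accumulator of projected outer products

def whiten_in_subspace(g: torch.Tensor, C: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    p = basis.T @ g                                # project to the subspace, shape (k,)
    C = beta * C + (1 - beta) * torch.outer(p, p)  # EMA of the projected second moment
    L, Q = torch.linalg.eigh(C)
    L = L.clamp(min=0) + damping                   # damp small/negative eigenvalues (assumed scheme)
    C_inv_sqrt = (Q * L.rsqrt()) @ Q.T             # C^(-1/2) from the eigendecomposition
    return basis @ (C_inv_sqrt @ p), C             # whitened update, back in full space

g = torch.randn(ndim)
update, C = whiten_in_subspace(g, C)
```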

torchzero/modules/experimental/coordinate_momentum.py

@@ -0,0 +1,36 @@
+import torch
+
+from ...core import TensorTransform
+from ...utils import NumberList, TensorList, unpack_states
+
+
+def coordinate_momentum_(
+    tensors: TensorList,
+    velocity_: TensorList,
+    p: float | NumberList,
+):
+    """
+    sets `velocity_` to p% random values from `tensors`.
+
+    Returns `velocity_`
+    """
+    mask = tensors.bernoulli_like(p).as_bool()
+    velocity_.masked_set_(mask, tensors)
+    return velocity_
+
+
+class CoordinateMomentum(TensorTransform):
+    """Maintains a momentum buffer, on each step each value in the buffer has ``p`` chance to be updated with the new value.
+
+    Args:
+        p (float, optional): _description_. Defaults to 0.1.
+    """
+    def __init__(self, p: float = 0.1):
+        defaults = dict(p=p)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        p = NumberList(s['p'] for s in settings)
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+        return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()