torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/sam.py

````diff
@@ -1,10 +1,10 @@
 from contextlib import nullcontext
 import torch
-from ...utils import TensorList, NumberList
-from ...core import
+from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
+from ...core import Transform


-class SAM(
+class SAM(Transform):
     """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

     SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
@@ -22,50 +22,51 @@ class SAM(Module):
         p (float, optional): norm of the SAM objective. Defaults to 2.
         asam (bool, optional):
             enables ASAM variant which makes perturbation relative to weight magnitudes.
-            ASAM requires a much larger
-            The
-            it has larger
+            ASAM requires a much larger ``rho``, like 0.5 or 1.
+            The ``tz.m.ASAM`` class is idential to setting this argument to True, but
+            it has larger ``rho`` by default.

-    Examples:
-        SAM-SGD:
+    ### Examples:

-
+    SAM-SGD:

-
-
-
-
-
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SAM(),
+        tz.m.LR(1e-2)
+    )
+    ```

-
+    SAM-Adam:

-
-
-
-
-
-
-
-
+    ```
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SAM(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2)
+    )
+    ```

     References:
-        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.
+        [Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.](https://arxiv.org/abs/2010.01412#page=3.16)
     """
     def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
         defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
         super().__init__(defaults)

     @torch.no_grad
-    def
+    def update_states(self, objective, states, settings):

-        params =
-        closure =
-        zero_grad =
+        params = objective.params
+        closure = objective.closure
+        zero_grad = objective.zero_grad
         if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
-        p, rho =
-
-        eps =
-        asam =
+        p, rho = unpack_dicts(settings, 'p', 'rho', cls=NumberList)
+        fs = settings[0]
+        eps = fs['eps']
+        asam = fs['asam']

         # 1/p + 1/q = 1
         # okay, authors of SAM paper, I will manually solve your equation
@@ -123,8 +124,7 @@ class SAM(Module):

             return sam_loss

-
-        return var
+        objective.closure = sam_closure

 # different class because defaults for SAM are bad for ASAM
 class ASAM(SAM):
@@ -136,7 +136,7 @@ class ASAM(SAM):
     This implementation modifies the closure to return loss and calculate gradients
     of the SAM objective. All modules after this will use the modified objective.

-
+    Note:
         This module requires a closure passed to the optimizer step,
         as it needs to re-evaluate the loss and gradients at two points on each step.

@@ -144,20 +144,30 @@ class ASAM(SAM):
         rho (float, optional): Neighborhood size. Defaults to 0.05.
         p (float, optional): norm of the SAM objective. Defaults to 2.

-    Examples:
-
+    ### Examples:
+
+    ASAM-SGD:

-
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.ASAM(),
+        tz.m.LR(1e-2)
+    )
+    ```

-
-        model.parameters(),
-        tz.m.ASAM(),
-        tz.m.Adam(),
-        tz.m.LR(1e-2)
-    )
+    ASAM-Adam:

+    ```
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.ASAM(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2)
+    )
+    ```
     References:
-        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July).
+        [Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR.](https://arxiv.org/abs/2102.11600)
     """
     def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
         super().__init__(rho=rho, p=p, eps=eps, asam=True)
````
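The rewritten `update_states` builds a `sam_closure` that perturbs the parameters toward the worst case within a ``rho``-ball (the dual-norm direction of the gradient), re-evaluates loss and gradients there, and restores the weights, so downstream modules see gradients of the sharpness-aware objective. As a rough standalone illustration of that step for the common p=2 case (plain PyTorch only, not torchzero's implementation; `sam_step`, its arguments, and the closure convention here are illustrative assumptions):

```python
# Sketch of one SAM step with p=2: perturb by rho * g / ||g||_2, take the
# gradient at the perturbed point, restore weights, then let a base optimizer step.
import torch

def sam_step(params, closure, base_opt, rho=0.05, eps=1e-10):
    """closure() should zero gradients, compute the loss, call backward(), and return the loss."""
    loss = closure()  # gradients at the current point
    grads = [p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
             for p in params]
    grad_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))

    with torch.no_grad():
        # ascend to the (approximate) worst-case point inside the rho-ball
        perturbs = [g * (rho / (grad_norm + eps)) for g in grads]
        for p, e in zip(params, perturbs):
            p.add_(e)

    closure()  # gradients of the SAM objective at the perturbed point
    with torch.no_grad():
        for p, e in zip(params, perturbs):
            p.sub_(e)  # restore original parameters

    base_opt.step()       # update using the perturbed-point gradients
    base_opt.zero_grad()
    return loss
```

For general ``p``, the perturbation follows the closed-form maximizer from the SAM paper (with 1/p + 1/q = 1, the comment the code above refers to); the ASAM variant additionally scales the perturbation by the weight magnitudes, which is why it needs a larger ``rho``.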
torchzero/modules/adaptive/shampoo.py

````diff
@@ -1,11 +1,10 @@
 from collections.abc import Sequence
-
-from functools import partial
+
 import numpy as np
 import torch

-from ...core import Chainable,
-from ...
+from ...core import Chainable, TensorTransform
+from ...linalg.matrix_power import MatrixPowerMethod, matrix_power as _matrix_power
 from ...utils import set_storage_


@@ -14,10 +13,11 @@ def update_shampoo_preconditioner_(
     accumulators_: list[torch.Tensor | None],
     preconditioners_: list[torch.Tensor | None],
     step: int,
-
-
+    precond_freq: int,
+    matrix_power: float | None,
     beta: float | None,
-    reg: float
+    reg: float,
+    matrix_power_method: MatrixPowerMethod,
 ):
     for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
         if accumulator is None: continue
@@ -27,22 +27,20 @@ def update_shampoo_preconditioner_(
         if beta is None: accumulator.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
         else: accumulator.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

-        if step %
-            matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
+        if step % precond_freq == 0:
             if reg != 0:
                 accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
-            set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))

+            if matrix_power is None: matrix_power = -1 / max(grad.ndim, 2)
+            set_storage_(preconditioner, _matrix_power(accumulator, matrix_power, method=matrix_power_method))

 def apply_shampoo_preconditioner(
     tensor: torch.Tensor,
     preconditioners_: list[torch.Tensor | None],
-    decay: float | None,
 ):
     for i, preconditioner in enumerate(preconditioners_):
         if preconditioner is None: continue
         tensor = torch.tensordot(tensor, preconditioner, ([0], [0])) # pyright:ignore[reportArgumentType]
-        if decay is not None: preconditioner.mul_(decay)
     return tensor


@@ -50,9 +48,8 @@ def update_diagonal_(grad: torch.Tensor, diagonal_accumulator_: torch.Tensor, be
     if beta is None: diagonal_accumulator_.add_(grad.pow(2))
     else: diagonal_accumulator_.mul_(beta).addcmul_(grad, grad, value=1-beta)

-def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor,
+def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor, eps: float):
     grad_.div_(diagonal_accumulator_.sqrt() + eps)
-    if decay is not None: diagonal_accumulator_.mul_(decay)
     return grad_

 def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
@@ -86,144 +83,141 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     return tensor.permute(*np.argsort(sort_idxs).tolist())


-class Shampoo(
+class Shampoo(TensorTransform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

-
+    Notes:
         Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.

-
-        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+        Shampoo is a very computationally expensive optimizer, increase ``update_freq`` if it is too slow.

-
-        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as ``tz.m.SOAP``.

     Args:
-        decay (float | None, optional): slowly decays preconditioners. Defaults to None.
-        beta (float | None, optional):
-            if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
         update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-
+        matrix_power (float | None, optional): overrides matrix exponent. By default uses ``-1/grad.ndim``. Defaults to None.
         merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
-        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to
+        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 10_000.
         precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
         adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
+        matrix_power_method (MatrixPowerMethod, optional): how to compute matrix power.
+        beta (float | None, optional):
+            if None calculates sum as in standard Shampoo, otherwise uses EMA of preconditioners. Defaults to None.
         inner (Chainable | None, optional):
             module applied after updating preconditioners and before applying preconditioning.
             For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
             Defaults to None.

     Examples:
-        … (23 removed lines; their contents are not preserved in this diff view)
+        Shampoo grafted to Adam
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.GraftModules(
+                direction = tz.m.Shampoo(),
+                magnitude = tz.m.Adam(),
+            ),
+            tz.m.LR(1e-3)
+        )
+        ```
+
+        Adam with Shampoo preconditioner
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+            tz.m.Debias(0.9, 0.999),
+            tz.m.LR(1e-3)
+        )
+        ```
     """
     def __init__(
         self,
-        decay: float | None = None,
-        beta: float | None = None,
         reg: float = 1e-12,
-
-
+        precond_freq: int = 10,
+        matrix_power: float | None = None,
         merge_small: bool = True,
-        max_dim: int =
+        max_dim: int = 10_000,
         precondition_1d: bool = True,
         adagrad_eps: float = 1e-8,
+        matrix_power_method: MatrixPowerMethod = "eigh_abs",
+        beta: float | None = None,
+        beta_debias: bool = True,
+
         inner: Chainable | None = None,
     ):
-        defaults =
-        … (5 removed lines; not preserved in this diff view)
-    def
-        … (66 removed lines; not preserved in this diff view)
-            state['step'] += 1
-
-        return tensors
+        defaults = locals().copy()
+        del defaults['self'], defaults["inner"]
+
+        super().__init__(defaults, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if tensor.ndim <= 1 and not setting["precondition_1d"]:
+            state["accumulators"] = []
+
+        else:
+            max_dim = setting["max_dim"]
+            state['accumulators'] = [
+                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+            ]
+            state['preconditioners'] = [
+                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+            ]
+
+        # either scalar parameter, 1d with precondition_1d=False, or too big, then diagonal preconditioner is used.
+        if len([i is not None for i in state['accumulators']]) == 0:
+            state['diagonal_accumulator'] = torch.zeros_like(tensor)
+
+        state['step'] = 0
+        state["num_GTG"] = 0
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if 'diagonal_accumulator' in state:
+            update_diagonal_(tensor, state['diagonal_accumulator'], beta=setting["beta"])
+        else:
+            update_shampoo_preconditioner_(
+                tensor,
+                accumulators_=state['accumulators'],
+                preconditioners_=state['preconditioners'],
+                step=state['step'],
+                precond_freq=setting["precond_freq"],
+                matrix_power=setting["matrix_power"],
+                beta=setting["beta"],
+                reg=setting["reg"],
+                matrix_power_method=setting["matrix_power_method"],
+            )
+
+        if state["step"] % setting["precond_freq"] == 0:
+            state["num_GTG"] += 1
+
+        state["step"] += 1
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if 'diagonal_accumulator' in state:
+            dir = apply_diagonal_(tensor, state['diagonal_accumulator'], eps=setting["adagrad_eps"])
+        else:
+            dir = apply_shampoo_preconditioner(tensor, preconditioners_=state['preconditioners'])
+
+        if setting["merge_small"]:
+            dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])
+
+        if setting['beta_debias'] and setting["beta"] is not None:
+            bias_correction = 1 - (setting["beta"] ** state["num_GTG"])
+            dir *= bias_correction ** 0.5
+
+        return dir
+
````
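For a matrix-shaped gradient, the helper functions in this file amount to: accumulate one Gram matrix per tensor dimension, take a fractional matrix power of each accumulator every ``precond_freq`` steps, and contract the gradient against the results. A compact standalone sketch of that computation (plain PyTorch; `matrix_power_eigh` and `shampoo_direction` are illustrative names, not torchzero API — the package routes this through `torchzero.linalg.matrix_power` with a configurable `matrix_power_method`):

```python
# Minimal sketch of the core Shampoo update for a single 2-D gradient, mirroring
# update_shampoo_preconditioner_ / apply_shampoo_preconditioner per dimension.
import torch

def matrix_power_eigh(mat: torch.Tensor, power: float, reg: float = 1e-12) -> torch.Tensor:
    """Fractional power of a symmetric PSD matrix via an eigendecomposition."""
    mat = mat + torch.eye(mat.size(0), dtype=mat.dtype, device=mat.device) * reg
    eigvals, eigvecs = torch.linalg.eigh(mat)
    eigvals = eigvals.clamp(min=reg) ** power
    return (eigvecs * eigvals) @ eigvecs.mH     # V diag(lambda^power) V^T

def shampoo_direction(grad: torch.Tensor, L: torch.Tensor, R: torch.Tensor,
                      exponent: float | None = None) -> torch.Tensor:
    """One update for a matrix gradient; L and R are running accumulators (start at identity)."""
    L += grad @ grad.T                          # row-space Gram matrix
    R += grad.T @ grad                          # column-space Gram matrix
    if exponent is None:
        exponent = -1.0 / max(grad.ndim, 2)     # default exponent used in the diff above
    PL = matrix_power_eigh(L, exponent)
    PR = matrix_power_eigh(R, exponent)
    # contracting dim 0 with each preconditioner in turn equals PL @ grad @ PR here
    return torch.tensordot(torch.tensordot(grad, PL, ([0], [0])), PR, ([0], [0]))
```

Starting from `L = torch.eye(grad.shape[0])` and `R = torch.eye(grad.shape[1])`, repeated calls reproduce the basic behaviour of the class above for one 2-D parameter, leaving out the merging of small dimensions, the diagonal (Adagrad) fallback, the update frequency, and the EMA/debias handling that `Shampoo` layers on top.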