torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
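The renames in the listing above show the linear-algebra helpers moving out of `torchzero/utils/linalg/` into a new top-level `torchzero/linalg/` package (`linear_operator.py`, `qr.py`, `solve.py`, plus new `eigh.py`, `matrix_power.py`, `orthogonalize.py`, `svd.py`). A minimal sketch of what that move implies for downstream imports, assuming the submodule names match the file names listed above (this diff does not show what `torchzero/linalg/__init__.py` re-exports):

```python
# Hypothetical import update, based only on the file moves listed above.
# torchzero 0.3.15: torchzero/utils/linalg/{linear_operator,qr,solve}.py
# torchzero 0.4.1:  torchzero/linalg/{linear_operator,qr,solve}.py
from torchzero.linalg import linear_operator, qr, solve  # plain submodule imports; re-exported names unverified
```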
torchzero/modules/adaptive/shampoo.py

````diff
@@ -1,11 +1,10 @@
-from collections.abc import Sequence
-
-from functools import partial
+from collections.abc import Sequence, Iterable
+
 import numpy as np
 import torch
 
-from ...core import Chainable,
-from ...
+from ...core import Chainable, TensorTransform
+from ...linalg.matrix_power import MatrixPowerMethod, matrix_power as _matrix_power
 from ...utils import set_storage_
 
 
@@ -14,10 +13,11 @@ def update_shampoo_preconditioner_(
     accumulators_: list[torch.Tensor | None],
     preconditioners_: list[torch.Tensor | None],
     step: int,
-
-
+    precond_freq: int,
+    matrix_power: float | None,
     beta: float | None,
-    reg: float
+    reg: float,
+    matrix_power_method: MatrixPowerMethod,
 ):
     for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
         if accumulator is None: continue
@@ -27,22 +27,20 @@ def update_shampoo_preconditioner_(
         if beta is None: accumulator.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
         else: accumulator.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
 
-        if step %
-        matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
+        if step % precond_freq == 0:
             if reg != 0:
                 accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
-            set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))
 
+            if matrix_power is None: matrix_power = -1 / max(grad.ndim, 2)
+            set_storage_(preconditioner, _matrix_power(accumulator, matrix_power, method=matrix_power_method))
 
 def apply_shampoo_preconditioner(
     tensor: torch.Tensor,
     preconditioners_: list[torch.Tensor | None],
-    decay: float | None,
 ):
     for i, preconditioner in enumerate(preconditioners_):
         if preconditioner is None: continue
         tensor = torch.tensordot(tensor, preconditioner, ([0], [0])) # pyright:ignore[reportArgumentType]
-        if decay is not None: preconditioner.mul_(decay)
     return tensor
 
 
@@ -50,9 +48,8 @@ def update_diagonal_(grad: torch.Tensor, diagonal_accumulator_: torch.Tensor, be
     if beta is None: diagonal_accumulator_.add_(grad.pow(2))
     else: diagonal_accumulator_.mul_(beta).addcmul_(grad, grad, value=1-beta)
 
-def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor,
+def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor, eps: float):
     grad_.div_(diagonal_accumulator_.sqrt() + eps)
-    if decay is not None: diagonal_accumulator_.mul_(decay)
     return grad_
 
 def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
@@ -85,145 +82,167 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     tensor = tensor.unflatten(0, flat_sizes)
     return tensor.permute(*np.argsort(sort_idxs).tolist())
 
+def diagonal_memory(params: torch.nn.Module | torch.Tensor | Iterable[torch.Tensor]):
+    """computes number of parameters"""
+    if isinstance(params, torch.nn.Module): params = params.parameters()
+    if isinstance(params, torch.Tensor): params = [params,]
+    params = list(params)
+    return sum(p.numel() for p in params)
+
+def kronecker_memory(params: torch.nn.Module | torch.Tensor | Iterable[torch.Tensor], merge_small:bool=True, max_dim:int=10_000):
+    """computes total size of tensors required to store shampoo preconditioner"""
+    if isinstance(params, torch.nn.Module): params = params.parameters()
+    if isinstance(params, torch.Tensor): params = [params,]
+    params = list(params)
+
+    memory = 0
+    for p in params:
+        if merge_small:
+            p, _, _ = _merge_small_dims(p, max_dim)
+        for dim in p.size():
+            if dim > max_dim: memory += dim
+            else: memory += dim**2
+
+    return memory
+
 
-
+
+
+class Shampoo(TensorTransform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).
 
-
+    Notes:
         Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.
 
-
-        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+        Shampoo is a very computationally expensive optimizer, increase ``update_freq`` if it is too slow.
 
-
-        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as ``tz.m.SOAP``.
 
     Args:
-        decay (float | None, optional): slowly decays preconditioners. Defaults to None.
-        beta (float | None, optional):
-            if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
         update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-
+        matrix_power (float | None, optional): overrides matrix exponent. By default uses ``-1/grad.ndim``. Defaults to None.
         merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
-        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to
+        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 10_000.
         precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
         adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
+        matrix_power_method (MatrixPowerMethod, optional): how to compute matrix power.
+        beta (float | None, optional):
+            if None calculates sum as in standard Shampoo, otherwise uses EMA of preconditioners. Defaults to None.
        inner (Chainable | None, optional):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.
 
    Examples:
-        [old lines 117-139 removed; content not captured in this view]
+        Shampoo grafted to Adam
+
+        ```python
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.GraftModules(
+                direction = tz.m.Shampoo(),
+                magnitude = tz.m.Adam(),
+            ),
+            tz.m.LR(1e-3)
+        )
+        ```
+
+        Adam with Shampoo preconditioner
+
+        ```python
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+            tz.m.Debias(0.9, 0.999),
+            tz.m.LR(1e-3)
+        )
+        ```
    """
    def __init__(
        self,
-        decay: float | None = None,
-        beta: float | None = None,
        reg: float = 1e-12,
-
-
+        precond_freq: int = 10,
+        matrix_power: float | None = None,
        merge_small: bool = True,
-        max_dim: int =
+        max_dim: int = 10_000,
        precondition_1d: bool = True,
        adagrad_eps: float = 1e-8,
+        matrix_power_method: MatrixPowerMethod = "eigh_abs",
+        beta: float | None = None,
+        beta_debias: bool = True,
+
        inner: Chainable | None = None,
    ):
-        defaults =
-
-
-
-
-
-    def
-        [old lines 161-226 removed; content not captured in this view]
-        state['step'] += 1
-
-        return tensors
+        defaults = locals().copy()
+        del defaults['self'], defaults["inner"]
+
+        super().__init__(defaults, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if tensor.ndim <= 1 and not setting["precondition_1d"]:
+            state["accumulators"] = []
+
+        else:
+            max_dim = setting["max_dim"]
+            state['accumulators'] = [
+                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+            ]
+            state['preconditioners'] = [
+                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+            ]
+
+        # either scalar parameter, 1d with precondition_1d=False, or too big, then diagonal preconditioner is used.
+        if len([i is not None for i in state['accumulators']]) == 0:
+            state['diagonal_accumulator'] = torch.zeros_like(tensor)
+
+        state['step'] = 0
+        state["num_GTG"] = 0
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if 'diagonal_accumulator' in state:
+            update_diagonal_(tensor, state['diagonal_accumulator'], beta=setting["beta"])
+        else:
+            update_shampoo_preconditioner_(
+                tensor,
+                accumulators_=state['accumulators'],
+                preconditioners_=state['preconditioners'],
+                step=state['step'],
+                precond_freq=setting["precond_freq"],
+                matrix_power=setting["matrix_power"],
+                beta=setting["beta"],
+                reg=setting["reg"],
+                matrix_power_method=setting["matrix_power_method"],
+            )
+
+        if state["step"] % setting["precond_freq"] == 0:
+            state["num_GTG"] += 1
+
+        state["step"] += 1
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        if 'diagonal_accumulator' in state:
+            dir = apply_diagonal_(tensor, state['diagonal_accumulator'], eps=setting["adagrad_eps"])
+        else:
+            dir = apply_shampoo_preconditioner(tensor, preconditioners_=state['preconditioners'])
+
+        if setting["merge_small"]:
+            dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])
+
+        if setting['beta_debias'] and setting["beta"] is not None:
+            bias_correction = 1 - (setting["beta"] ** state["num_GTG"])
+            dir *= bias_correction ** 0.5
+
+        return dir
+
````
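To summarize the hunk above: the 0.4.1 `Shampoo` constructor drops `decay`, renames the update interval to `precond_freq`, and adds `matrix_power`, `matrix_power_method`, and `beta_debias`; the default matrix exponent also changes from `-1/(grad.ndim*2)` in 0.3.15 to `-1/max(grad.ndim, 2)` in 0.4.1, as visible in the removed and added lines of the second hunk. Below is a short, hedged sketch of a 0.4.1-style call site that reuses the grafting pattern from the new docstring and writes out the added options at their defaults from the diff; the import path for the new memory helpers is assumed from the file layout above and may differ if they are re-exported elsewhere.

```python
import torch
import torchzero as tz  # alias used by the docstring examples above

# New helpers added in shampoo.py in 0.4.1; import path assumed from the file layout.
from torchzero.modules.adaptive.shampoo import diagonal_memory, kronecker_memory

model = torch.nn.Linear(128, 64)

# Shampoo grafted to Adam (the docstring's recommended setup), with the options
# introduced in 0.4.1 written out explicitly at their default values.
opt = tz.Optimizer(
    model.parameters(),
    tz.m.GraftModules(
        direction=tz.m.Shampoo(precond_freq=10, matrix_power_method="eigh_abs",
                               beta=None, beta_debias=True),
        magnitude=tz.m.Adam(),
    ),
    tz.m.LR(1e-3),
)

# Storage cost of a diagonal (Adagrad-style) vs. Kronecker-factored (Shampoo) preconditioner.
print(diagonal_memory(model))   # equals the parameter count: 128*64 + 64 = 8256
print(kronecker_memory(model))  # sum of dim**2 (or dim, above max_dim) after optional small-dim merging
```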