torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
|
@@ -1,49 +1,20 @@
|
|
|
1
|
-
import math
|
|
2
|
-
from collections.abc import Callable
|
|
3
1
|
from typing import Literal
|
|
4
2
|
|
|
5
3
|
import torch
|
|
6
4
|
|
|
7
|
-
from ...core import Chainable,
|
|
8
|
-
from ...utils import NumberList, TensorList,
|
|
5
|
+
from ...core import Chainable, HVPMethod, Transform
|
|
6
|
+
from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
|
|
9
7
|
|
|
10
8
|
|
|
11
|
-
|
|
12
|
-
tensors_: TensorList,
|
|
13
|
-
D: TensorList | None,
|
|
14
|
-
D_sq_acc_: TensorList,
|
|
15
|
-
damping: float | NumberList,
|
|
16
|
-
update_freq: int,
|
|
17
|
-
step: int,
|
|
18
|
-
i: int,
|
|
19
|
-
):
|
|
20
|
-
# update preconditioner
|
|
21
|
-
if step % update_freq == 0:
|
|
22
|
-
assert D is not None
|
|
23
|
-
D_sq_acc_.addcmul_(D, D)
|
|
24
|
-
i += 1
|
|
25
|
-
else:
|
|
26
|
-
assert D is None
|
|
27
|
-
|
|
28
|
-
denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
|
|
29
|
-
return tensors_.div_(denom), i
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class ESGD(Module):
|
|
9
|
+
class ESGD(Transform):
|
|
33
10
|
"""Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)
|
|
34
11
|
|
|
35
12
|
This is similar to Adagrad, but the accumulates squared randomized hessian diagonal estimates instead of squared gradients.
|
|
36
13
|
|
|
37
|
-
|
|
38
|
-
In most cases
|
|
14
|
+
Notes:
|
|
15
|
+
- In most cases ESGD should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply ESGD preconditioning to another module's output.
|
|
39
16
|
|
|
40
|
-
|
|
41
|
-
If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
|
|
42
|
-
|
|
43
|
-
.. note::
|
|
44
|
-
This module requires a closure passed to the optimizer step,
|
|
45
|
-
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
46
|
-
The closure must accept a ``backward`` argument (refer to documentation).
|
|
17
|
+
- This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
|
|
47
18
|
|
|
48
19
|
Args:
|
|
49
20
|
damping (float, optional): added to denominator for stability. Defaults to 1e-4.
|
|
@@ -51,17 +22,17 @@ class ESGD(Module):
|
|
|
51
22
|
frequency of updating hessian diagonal estimate via a hessian-vector product.
|
|
52
23
|
This value can be increased to reduce computational cost. Defaults to 20.
|
|
53
24
|
hvp_method (str, optional):
|
|
54
|
-
Determines how
|
|
55
|
-
|
|
56
|
-
- ``"
|
|
57
|
-
|
|
58
|
-
- ``"
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
25
|
+
Determines how hessian-vector products are computed.
|
|
26
|
+
|
|
27
|
+
- ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
|
|
28
|
+
- ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
|
|
29
|
+
- ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
|
|
30
|
+
- ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
|
|
31
|
+
|
|
32
|
+
Defaults to ``"autograd"``.
|
|
33
|
+
h (float, optional):
|
|
34
|
+
The step size for finite difference if ``hvp_method`` is
|
|
35
|
+
``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
|
|
65
36
|
n_samples (int, optional):
|
|
66
37
|
number of hessian-vector products with random vectors to evaluate each time when updating
|
|
67
38
|
the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
|
|
@@ -72,100 +43,108 @@ class ESGD(Module):
|
|
|
72
43
|
2. pass inputs to :code:`inner`.
|
|
73
44
|
3. momentum and preconditioning are applied to the ouputs of :code:`inner`.
|
|
74
45
|
|
|
75
|
-
Examples:
|
|
76
|
-
Using ESGD:
|
|
77
|
-
|
|
78
|
-
.. code-block:: python
|
|
46
|
+
### Examples:
|
|
79
47
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
tz.m.ESGD(),
|
|
83
|
-
tz.m.LR(0.1)
|
|
84
|
-
)
|
|
48
|
+
Using ESGD:
|
|
49
|
+
```python
|
|
85
50
|
|
|
86
|
-
|
|
87
|
-
|
|
51
|
+
opt = tz.Optimizer(
|
|
52
|
+
model.parameters(),
|
|
53
|
+
tz.m.ESGD(),
|
|
54
|
+
tz.m.LR(0.1)
|
|
55
|
+
)
|
|
56
|
+
```
|
|
88
57
|
|
|
89
|
-
|
|
58
|
+
ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
|
|
59
|
+
ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):
|
|
90
60
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
61
|
+
```python
|
|
62
|
+
opt = tz.Optimizer(
|
|
63
|
+
model.parameters(),
|
|
64
|
+
tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
|
|
65
|
+
tz.m.LR(0.1)
|
|
66
|
+
)
|
|
67
|
+
```
|
|
96
68
|
|
|
97
69
|
"""
|
|
98
70
|
def __init__(
    self,
    damping: float = 1e-4,
    update_freq: int = 20,
    distribution: Distributions = 'gaussian',
    hvp_method: HVPMethod = 'autograd',
    h: float = 1e-3,
    n_samples = 1,
    zHz: bool = False,
    seed: int | None = None,
    beta: float | None = None,
    beta_debias: bool = True,

    inner: Chainable | None = None,
    Hz_sq_acc_tfm: Chainable | None = None,
):
    """Store the ESGD hyperparameters as defaults and register child modules.

    Every constructor argument except ``inner`` and ``Hz_sq_acc_tfm`` is stored as a
    default setting. ``inner`` is forwarded to the base class, and ``Hz_sq_acc_tfm``
    is registered as the ``"Hz_sq_acc"`` child.
    """
    # snapshot the arguments before any other local name is bound
    defaults = locals().copy()
    for excluded in ('self', 'inner', "Hz_sq_acc_tfm"):
        del defaults[excluded]

    super().__init__(defaults, inner=inner)
    self.set_child("Hz_sq_acc", Hz_sq_acc_tfm)
|
|
113
91
|
|
|
114
92
|
@torch.no_grad
def update_states(self, objective, states, settings):
    """Every ``update_freq`` steps, fold a fresh Hutchinson estimate ``Hz`` into ``Hz_sq_acc``.

    Without ``beta`` the squared estimates ``Hz*Hz`` are summed. With ``beta`` they are
    tracked as an exponential moving average. On other steps only the ``"step"``
    counter advances.
    """
    group_settings = settings[0]
    params = objective.params
    update_freq = group_settings['update_freq']

    step = self.increment_counter("step", start=0)
    if step % update_freq != 0:
        return

    self.increment_counter("num_Hzs", start=1)

    generator = self.get_generator(params[0].device, group_settings["seed"])
    Hz, _ = objective.hutchinson_hessian(
        rgrad=None,
        at_x0=True,
        n_samples=group_settings['n_samples'],
        distribution=group_settings['distribution'],
        hvp_method=group_settings['hvp_method'],
        h=group_settings['h'],
        zHz=group_settings["zHz"],  # default is False, so it returns Hz, not z⊙Hz
        generator=generator,
    )

    Hz = TensorList(Hz)
    Hz_sq_acc = unpack_states(states, params, 'Hz_sq_acc', cls=TensorList)

    # NOTE(review): the "Hz_sq_acc" child registered in __init__ is never applied
    # to the accumulator here - confirm whether that is intended.
    beta = group_settings["beta"]
    if beta is not None:
        # exponential moving average of the squared estimates
        Hz_sq_acc.mul_(beta).addcmul_(Hz, Hz, value=1-beta)
    else:
        # running sum of the squared estimates
        Hz_sq_acc.addcmul_(Hz, Hz)
|
|
125
|
+
|
|
126
|
+
@torch.no_grad
def apply_states(self, objective, states, settings):
    """Divide the current updates by the ESGD preconditioner ``sqrt(mean(Hz²)) + damping``.

    The accumulator is turned into a mean exactly once:

    - ``beta is None``: the running sum is divided by the number of estimates ``num_Hzs``;
    - ``beta`` set and ``beta_debias``: the EMA is divided by ``1 - beta**num_Hzs``;
    - ``beta`` set without debiasing: the EMA is used as is.

    (Previously the mean was additionally divided by ``num_Hzs`` a second time.)
    """
    tensors = TensorList(objective.get_updates())
    Hz_sq_acc = unpack_states(states, tensors, 'Hz_sq_acc', cls=TensorList)
    num_Hzs = self.global_state["num_Hzs"]
    fs = settings[0]

    # ---------------------------- average / debias ---------------------------- #
    beta = fs["beta"]
    if beta is None:
        Hz_sq_mean = Hz_sq_acc / num_Hzs
    elif fs["beta_debias"]:
        Hz_sq_mean = Hz_sq_acc / (1.0 - beta ** num_Hzs)
    else:
        # the EMA is already an average; copy so the in-place sqrt_ below
        # cannot overwrite the stored accumulator
        Hz_sq_mean = Hz_sq_acc.clone()

    # ---------------------------------- update ---------------------------------- #
    damping = [s["damping"] for s in settings]
    denom = Hz_sq_mean.sqrt_().add_(damping)

    objective.updates = tensors.div_(denom)
    return objective
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
from typing import Literal, Any
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from ...core import Chainable, TensorTransform
|
|
7
|
+
from ...linalg import torch_linalg, regularize_eigh
|
|
8
|
+
from .lre_optimizers import LREOptimizerBase
|
|
9
|
+
|
|
10
|
+
def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, eig_tol):
    """Compute the low-rank GGT preconditioner factors from a window of past updates.

    Args:
        history: a deque of flattened tensors, which are stacked as the columns of ``M``,
            or a tensor that is used directly as ``M`` (presumably already laid out as
            ``(ndim, k)`` columns - confirm with callers).
        damping: added to the diagonal of ``MᵀM`` before the eigendecomposition.
        rdamping: after ``U`` is computed, ``rdamping * L[-1]`` (the largest eigenvalue)
            is added to every eigenvalue.
        truncate: forwarded to ``regularize_eigh`` (number of largest eigenvalues to keep).
        eig_tol: forwarded to ``regularize_eigh`` as ``tol``.

    Returns:
        ``(L, U)`` with ``L`` of shape ``(rank,)`` and ``U`` of shape ``(ndim, rank)``.
        Returns ``(None, None)`` when the eigendecomposition raises ``LinAlgError``
        or no finite eigenvalues remain.
    """
    if isinstance(history, torch.Tensor):
        M = history
    else:
        # each stored update becomes one column of M
        M = torch.stack(tuple(history), dim=1)

    # eigendecomposition of the small Gram matrix MᵀM replaces an SVD of M
    MtM = M.T @ M
    if damping != 0:
        MtM.diagonal().add_(damping)

    try:
        L, Q = torch_linalg.eigh(MtM, retry_float64=True)

        # damping is already added to MtM, rdamping is added afterwards
        L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eig_tol, damping=0, rdamping=0)

        if L is None or Q is None: # this means there are no finite eigenvalues
            return None, None

        # left singular vectors of M: U = M Q diag(L)^(-1/2)
        U = (M @ Q) * L.rsqrt()

        # this damping is added after computing U, so it damps the singular values
        # rather than MtM (which is why regularize_eigh gets rdamping=0)
        if rdamping != 0:
            L.add_(rdamping * L[-1]) # L is sorted in ascending order

        return L, U

    except torch.linalg.LinAlgError:
        return None, None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class GGT(TensorTransform):
    """
    GGT method from https://arxiv.org/pdf/1806.02958

    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
    But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't need V.

    This is equivalent to full-matrix Adagrad on recent gradients.

    Args:
        history_size (int, optional): number of past gradients to store. Defaults to 100.
        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
        eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
        truncate (int, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
        damping (float, optional): damping added to the diagonal of MᵀM. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to largest eigenvalue. Defaults to 0.
        eigenbasis_optimizer (LREOptimizerBase | None, optional): if given, the update is computed by this
            optimizer in the eigenbasis (its state is re-projected whenever the basis is recomputed)
            instead of plain whitening. Defaults to None.
        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.

    ## Examples:

    Limited-memory Adagrad

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(),
        tz.m.LR(0.1)
    )
    ```
    Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(0.01)
    )
    ```

    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.ClipNormByEMA(max_ema_growth=1.2),
        tz.m.LR(0.01)
    )
    ```
    Reference:
        Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – pp. 102-110.
    """

    def __init__(
        self,
        history_size: int = 100,
        update_freq: int = 1,
        eig_tol: float = 1e-7,
        truncate: int | None = None,
        damping: float = 1e-4,
        rdamping: float = 0,
        eigenbasis_optimizer: LREOptimizerBase | None = None,
        concat_params: bool = True,

        inner: Chainable | None = None,
    ):
        # locals() must be captured before any other local name is bound
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults['concat_params']

        super().__init__(defaults, concat_params=concat_params, inner=inner)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        """Record ``tensor`` in the history and, every ``update_freq`` steps, recompute ``L`` and ``U``."""
        history_size = setting['history_size']
        update_freq = setting['update_freq']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        # store a flattened copy; reshape (unlike view) also accepts non-contiguous tensors
        history.append(tensor.reshape(-1).clone())

        step = state.get('step', 0)
        state['step'] = step + 1

        if step % update_freq != 0:
            return

        # compute new factors
        L = state.get("L", None)
        U = state.get("U", None)

        L_new, U_new = ggt_update(
            history,
            damping=setting["damping"],
            rdamping=setting["rdamping"],
            truncate=setting["truncate"],
            eig_tol=setting["eig_tol"],
        )

        # reproject eigenbasis optimizer state from the old basis into the new one
        eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
        if eigenbasis_optimizer is not None:
            if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
                # the state is only created once the optimizer has stepped; nothing to reproject before that
                eigenbasis_state = state.get("eigenbasis_state")
                if eigenbasis_state is not None:
                    eigenbasis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=eigenbasis_state)

        # store new factors; on failure (None) the previous factors are kept
        if L_new is not None: state["L"] = L_new
        if U_new is not None: state["U"] = U_new

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        """Precondition ``tensor`` with the stored factors; ``tensor`` itself is never modified."""
        g = tensor.reshape(-1)
        U = state.get('U', None)

        if U is None:
            # no factors yet (or every eigendecomposition so far failed):
            # fall back to element-wise RMS preconditioning over the stored history
            history = torch.stack(tuple(state["history"]), 0)
            update = g / history.square().mean(0).sqrt().add(1e-8)
            return update.view_as(tensor)

        L = state['L']

        # step with eigenbasis optimizer
        eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
        if eigenbasis_optimizer is not None:
            eigenbasis_state = state.setdefault("eigenbasis_state", {})
            update = eigenbasis_optimizer.step(g, L=L, Q=U, state=eigenbasis_state)
            return update.view_as(tensor)

        # or just whiten: U diag(L)^(-1/2) Uᵀ g
        z = U.T @ g
        update = (U * L.rsqrt()) @ z
        return update.view_as(tensor)
|
|
186
|
+
|
|
@@ -1,21 +1,17 @@
|
|
|
1
|
+
from typing import Any
|
|
1
2
|
import torch
|
|
2
3
|
|
|
3
|
-
from ...core import
|
|
4
|
+
from ...core import TensorTransform
|
|
4
5
|
from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
|
|
5
6
|
|
|
6
7
|
|
|
7
|
-
def lion_(tensors: TensorList | Any, exp_avg_: TensorList | Any, beta1, beta2,):
    """Lion update rule.

    Computes the update direction ``sign(lerp(exp_avg_, tensors, 1 - beta1))``. It then
    moves ``exp_avg_`` in place towards ``tensors`` with weight ``1 - beta2``.

    Returns:
        new tensors holding the sign update (``exp_avg_`` is modified in place).
    """
    interpolated = exp_avg_.lerp(tensors, 1 - beta1)
    direction = interpolated.sign_()
    # the momentum buffer is refreshed only after the direction has been read from it
    exp_avg_.lerp_(tensors, 1 - beta2)
    return direction
|
|
16
12
|
|
|
17
13
|
|
|
18
|
-
class Lion(
|
|
14
|
+
class Lion(TensorTransform):
|
|
19
15
|
"""Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.
|
|
20
16
|
|
|
21
17
|
Args:
|
|
@@ -25,11 +21,11 @@ class Lion(Transform):
|
|
|
25
21
|
|
|
26
22
|
def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
    """Register ``beta1`` and ``beta2`` as the default settings."""
    super().__init__({"beta1": beta1, "beta2": beta2})
|
|
29
25
|
|
|
30
26
|
@torch.no_grad
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
    """Return the Lion update for ``tensors``; the momentum buffer lives under the ``'ema'`` state key."""
    betas = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
    momentum = unpack_states(states, tensors, 'ema', cls=TensorList)
    first_beta, second_beta = betas
    return lion_(TensorList(tensors), momentum, first_beta, second_beta)
|
|
35
31
|
|