torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/esgd.py +98 -119

@@ -1,49 +1,20 @@
-import math
-from collections.abc import Callable
 from typing import Literal

 import torch

-from ...core import Chainable,
-from ...utils import NumberList, TensorList,
+from ...core import Chainable, HVPMethod, Transform
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states


-
-    tensors_: TensorList,
-    D: TensorList | None,
-    D_sq_acc_: TensorList,
-    damping: float | NumberList,
-    update_freq: int,
-    step: int,
-    i: int,
-):
-    # update preconditioner
-    if step % update_freq == 0:
-        assert D is not None
-        D_sq_acc_.addcmul_(D, D)
-        i += 1
-    else:
-        assert D is None
-
-    denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
-    return tensors_.div_(denom), i
-
-
-class ESGD(Module):
+class ESGD(Transform):
     """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

     This is similar to Adagrad, but the accumulates squared randomized hessian diagonal estimates instead of squared gradients.

-
-    In most cases
+    Notes:
+        - In most cases ESGD should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply ESGD preconditioning to another module's output.

-
-    If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
-
-    .. note::
-        This module requires a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

     Args:
         damping (float, optional): added to denominator for stability. Defaults to 1e-4.
@@ -51,17 +22,17 @@ class ESGD(Module):
             frequency of updating hessian diagonal estimate via a hessian-vector product.
             This value can be increased to reduce computational cost. Defaults to 20.
         hvp_method (str, optional):
-            Determines how
-
-            - ``"
-
-            - ``"
-
-
-
-
-
-
+            Determines how hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         n_samples (int, optional):
             number of hessian-vector products with random vectors to evaluate each time when updating
             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -72,100 +43,108 @@ class ESGD(Module):
     2. pass inputs to :code:`inner`.
     3. momentum and preconditioning are applied to the ouputs of :code:`inner`.

-    Examples:
-        Using ESGD:
-
-        .. code-block:: python
+    ### Examples:

-
-
-                tz.m.ESGD(),
-                tz.m.LR(0.1)
-            )
+    Using ESGD:
+    ```python

-
-
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.ESGD(),
+        tz.m.LR(0.1)
+    )
+    ```

-
+    ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
+    ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):

-
-
-
-
-
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```

     """
     def __init__(
         self,
         damping: float = 1e-4,
         update_freq: int = 20,
-
-
+        distribution: Distributions = 'gaussian',
+        hvp_method: HVPMethod = 'autograd',
+        h: float = 1e-3,
         n_samples = 1,
+        zHz: bool = False,
         seed: int | None = None,
-
+        beta: float | None = None,
+        beta_debias: bool = True,
+
+        inner: Chainable | None = None,
+        Hz_sq_acc_tfm: Chainable | None = None,
     ):
-        defaults =
-
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["Hz_sq_acc_tfm"]
+        super().__init__(defaults, inner=inner)

-
-        self.set_child('inner', inner)
+        self.set_child("Hz_sq_acc", Hz_sq_acc_tfm)

     @torch.no_grad
-    def
-        params =
-
-
-
-
-
-
-
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
-
-        damping = self.get_settings(params, 'damping', cls=NumberList)
-        D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
-        i = self.global_state.get('i', 0)
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        closure = var.closure
-        assert closure is not None
-
-        D = None
+    def update_states(self, objective, states, settings):
+        params = objective.params
+
+        fs = settings[0]
+        update_freq = fs['update_freq']
+
+        # ------------------------------- accumulate Hz ------------------------------ #
+        step = self.increment_counter("step", start=0)
+
         if step % update_freq == 0:
+            self.increment_counter("num_Hzs", start=1)
+
+            Hz, _ = objective.hutchinson_hessian(
+                rgrad = None,
+                at_x0 = True,
+                n_samples = fs['n_samples'],
+                distribution = fs['distribution'],
+                hvp_method = fs['hvp_method'],
+                h = fs['h'],
+                zHz = fs["zHz"], # default is False, so it returns Hz, not z⊙Hz
+                generator = self.get_generator(params[0].device, fs["seed"]),
+            )

-
-
-
+            Hz = TensorList(Hz)
+            Hz_sq_acc = unpack_states(states, params, 'Hz_sq_acc', cls=TensorList)
+
+            beta = fs["beta"]
+            if beta is None:
+                Hz_sq_acc.addcmul_(Hz, Hz)
+
+            else:
+                Hz_sq_acc.mul_(beta).addcmul_(Hz, Hz, value=1-beta)
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        tensors = TensorList(objective.get_updates())
+        Hz_sq_acc = unpack_states(states, tensors, 'Hz_sq_acc', cls=TensorList)
+        num_Hzs = self.global_state["num_Hzs"]
+        fs = settings[0]

-
-
+        # ---------------------------------- debias ---------------------------------- #
+        beta = fs["beta"]
+        beta_debias = fs["beta_debias"]

-
-
+        if beta_debias and beta is not None:
+            bias_correction = 1.0 - beta ** num_Hzs
+            Hz_sq_acc = Hz_sq_acc / bias_correction

-
-
+        else:
+            Hz_sq_acc = Hz_sq_acc / num_Hzs

-
+        # ---------------------------------- update ---------------------------------- #
+        damping = [s["damping"] for s in settings]

-
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+        denom = (Hz_sq_acc / num_Hzs).sqrt_().add_(damping)

-
-
-            D=TensorList(D) if D is not None else None,
-            D_sq_acc_=D_sq_acc,
-            damping=damping,
-            update_freq=update_freq,
-            step=step,
-            i=i,
-        )
-        return var
+        objective.updates = tensors.div_(denom)
+        return objective
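For orientation, here is a minimal usage sketch of the rewritten ESGD, based only on the constructor signature and the closure requirement shown in this diff; the model, data, and hyperparameter values are illustrative placeholders, not taken from the package:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.ESGD(update_freq=20, hvp_method="autograd", beta=0.99),  # beta / hvp_method are 0.4.0 arguments
    tz.m.LR(0.1),
)

# ESGD re-evaluates the loss for hessian-vector products,
# so the closure must accept a ``backward`` argument (see docstring above).
def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```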
torchzero/modules/adaptive/lion.py +5 -10

@@ -1,21 +1,16 @@
 import torch

-from ...core import
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


 def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
-    """
-    Lion update rule.
-
-    Returns new tensors.
-    """
     update = exp_avg_.lerp(tensors, 1-beta1).sign_()
     exp_avg_.lerp_(tensors, 1-beta2)
     return update


-class Lion(
+class Lion(TensorTransform):
     """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

     Args:
@@ -25,11 +20,11 @@ class Lion(Transform):

     def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
         defaults = dict(beta1=beta1, beta2=beta2)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
         exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
-        return lion_(TensorList(tensors),exp_avg,beta1,beta2)
+        return lion_(TensorList(tensors), exp_avg, beta1, beta2)

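The Lion update rule itself is unchanged by this release (sign of the interpolated momentum, then an EMA update of ``exp_avg_``); only the base class and the apply hook are renamed. A usage sketch in the same tz.Modular style as the ESGD docstring above, with a placeholder learning rate:

```python
# Lion emits a sign-based update, so it is normally paired with a small LR
opt = tz.Modular(
    model.parameters(),
    tz.m.Lion(beta1=0.9, beta2=0.99),  # defaults from the constructor shown above
    tz.m.LR(1e-4),
)
```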
torchzero/modules/adaptive/lmadagrad.py +87 -32

@@ -3,9 +3,11 @@ from typing import Literal, Any
 import warnings

 import torch
-from ...core import Chainable,
+from ...core import Chainable, TensorTransform
+from ...linalg import torch_linalg

-def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping):
+def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, tol):
+    """returns U ``(ndim, rank)``, L ``(rank, )``"""
     if isinstance(history, torch.Tensor):
         M = history
     else:
@@ -16,35 +18,49 @@ def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdam
     MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))

     try:
-        L, Q =
+        L, Q = torch_linalg.eigh(MTM, retry_float64=True)

-
-
-
-
+        # truncate to top n largest eigenvalues
+        if truncate is not None and truncate > 0:
+            # L is ordered in ascending order
+            L = L[-truncate:]
+            Q = Q[:, -truncate:]
+
+        # remove small eigenvalues relative to largest
+        L_max = L.amax()
+        indices = L > tol * L_max
+        if indices.any():
+            L = L[indices]
+            Q = Q[:, indices]

         U = (M @ Q) * L.rsqrt()

         if rdamping != 0:
-            rdamping
-            L.add_(rdamping)
+            L.add_(rdamping * L_max)

         return U, L

     except torch.linalg.LinAlgError:
         return None, None

-def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
-
-
+def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor, exp_avg_proj: torch.Tensor | None, beta:float):
+    z = U.T @ g
+
+    if beta != 0:
+        if exp_avg_proj is None: exp_avg_proj = torch.zeros_like(z)
+        exp_avg_proj.lerp_(z, weight=1-beta)
+        z = exp_avg_proj
+
+    return (U * L.rsqrt()) @ z, exp_avg_proj

 def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
-    if
+    if value is None: return
+    if (key not in state_) or (beta is None): state_[key] = value
     else:
         if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
         else: state_[key].lerp_(value, 1-beta)

-class LMAdagrad(
+class LMAdagrad(TensorTransform):
     """
     Limited-memory full matrix Adagrad.

@@ -55,17 +71,18 @@ class LMAdagrad(TensorwiseTransform):

     Args:
         history_size (int, optional): number of past gradients to store. Defaults to 10.
+        beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
         update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
         damping (float, optional): damping value. Defaults to 1e-4.
         rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
+        rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
+        truncate (int, optional): number of larges eigenvalues to keep. None to disable. Defaults to None.
+        tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
         order (int, optional):
             order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
-        true_damping (bool, optional):
-            If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
         U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
         L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
-
-        concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
+        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
         inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.

     ## Examples:
@@ -108,28 +125,35 @@ class LMAdagrad(TensorwiseTransform):
     def __init__(
         self,
         history_size: int = 100,
+        beta: float = 0.0,
         update_freq: int = 1,
         damping: float = 1e-4,
         rdamping: float = 0,
+        truncate: int | None = None,
+        tol: float = 1e-7,
         order: int = 1,
-        true_damping: bool = True,
         U_beta: float | None = None,
         L_beta: float | None = None,
-        interval: int = 1,
         concat_params: bool = True,
+
         inner: Chainable | None = None,
+        U_tfm: Chainable | None = None,
+        L_tfm: Chainable | None = None,
     ):
-
-        defaults
-
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults['concat_params'], defaults["U_tfm"], defaults["L_tfm"]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+        self.set_child("U", U_tfm)
+        self.set_child("L", L_tfm)
+

     @torch.no_grad
-    def
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
         order = setting['order']
         history_size = setting['history_size']
         update_freq = setting['update_freq']
-        damping = setting['damping']
-        rdamping = setting['rdamping']
         U_beta = setting['U_beta']
         L_beta = setting['L_beta']

@@ -165,22 +189,53 @@ class LMAdagrad(TensorwiseTransform):

         step = state.get('step', 0)
         if step % update_freq == 0 and len(history) != 0:
-
+
+            # if maintaining momentum, unproject exp_avg before updating factors and reproject
+            exp_avg_proj = state.get("exp_avg_proj", None)
+            exp_avg = None
+            if exp_avg_proj is not None and "U" in state:
+                exp_avg = state["U"] @ exp_avg_proj
+
+            # update factors
+            U, L = lm_adagrad_update(
+                history,
+                damping=setting["damping"],
+                rdamping=setting["rdamping"],
+                truncate=setting["truncate"],
+                tol=setting["tol"],
+            )
             maybe_lerp_(state, U_beta, 'U', U)
             maybe_lerp_(state, L_beta, 'L', L)

+            # re-project exp_avg with new factors
+            if U is not None and exp_avg_proj is not None:
+                assert exp_avg is not None
+                state["exp_avg_proj"] = U.T @ exp_avg
+
+
         if len(history) != 0:
             state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

     @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         U = state.get('U', None)
         if U is None:
             # make a conservative step to avoid issues due to different GD scaling
-            return tensor.clip_(-0.1, 0.1)
+            return tensor.clip_(-0.1, 0.1)

+        # -------------------------------- transforms -------------------------------- #
         L = state['L']
-
-
-
+        if "L" in self.children:
+            if not self._concat_params: raise RuntimeError("L/U transforms can only be used with concat_params=True")
+            L = self.inner_step_tensors("L", [L], clone=True)[0]
+
+        if "U" in self.children:
+            if not self._concat_params: raise RuntimeError("L/U transforms can only be used with concat_params=True")
+            U = self.inner_step_tensors("U", [U], clone=True)[0]
+
+        # ------------------------------- precondition ------------------------------- #
+        g = tensor.view(-1)
+        exp_avg_proj = state.get("exp_avg_proj", None)
+        update, state["exp_avg_proj"] = lm_adagrad_apply(g, U, L, exp_avg_proj, beta=setting["beta"])
+        return update.view_as(tensor)

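A usage sketch of LMAdagrad with the parameters added in 0.4.0 (``beta``, ``truncate``, ``tol``), again following the tz.Modular pattern used in these docstrings; the specific values are placeholders:

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.LMAdagrad(
        history_size=100,  # past gradients kept for the preconditioner
        beta=0.9,          # new: momentum maintained in the whitened space
        truncate=10,       # new: keep only the largest eigenvalues
        tol=1e-7,          # new: drop eigenvalues far smaller than the largest
    ),
    tz.m.LR(1e-2),
)
```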
torchzero/modules/adaptive/mars.py +5 -5

@@ -1,6 +1,6 @@
 import torch

-from ...core import
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


@@ -20,7 +20,7 @@ def mars_correction_(

     return c

-class MARSCorrection(
+class MARSCorrection(TensorTransform):
     """MARS variance reduction correction.

     Place any other momentum-based optimizer after this,
@@ -61,11 +61,11 @@ class MARSCorrection(Transform):
         scaling: float = 0.025,
         max_norm: float | None = 1,
     ):
-        defaults=dict(beta=beta, scaling=scaling, max_norm=max_norm)
-        super().__init__(defaults
+        defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
         beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
         max_norm = settings[0]['max_norm']