torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/optimizers/adagrad.py
@@ -1,49 +1,146 @@
+from operator import itemgetter
+
+import torch
+
+from ...core import (
+    Chainable,
+    Module,
+    Preconditioner,
+    Target,
+    TensorwisePreconditioner,
+    Transform,
+    Vars,
+    apply,
+)
+from ...utils import NumberList, TensorList
+from ...utils.linalg import matrix_power_eigh
+from ..functional import add_power_, lerp_power_, root
+
+
+def adagrad_(
+    tensors_: TensorList,
+    sq_sum_: TensorList,
+    alpha: float | NumberList,
+    lr_decay: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    use_sqrt: bool = True,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+    vars: Vars | None = None,
+):
+    """returns `tensors_`"""
+    clr = alpha / (1 + step * lr_decay)
+
+    sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply(inner, tensors_, params=params, grads=grads, vars=vars))
+
+    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
+    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
+
+    return tensors_
+
+
+class Adagrad(Transform):
+    """Adagrad, divides by sum of past squares of gradients, matches pytorch Adagrad.
+
+    Args:
+        lr_decay (float, optional): learning rate decay. Defaults to 0.
+        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+        eps (float, optional): division epsilon. Defaults to 1e-10.
+        alpha (float, optional): step size. Defaults to 1.
+        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        lr_decay: float = 0,
+        initial_accumulator_value: float = 0,
+        eps: float = 1e-10,
+        alpha: float = 1,
+        pow: float = 2,
+        use_sqrt: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha=alpha, lr_decay=lr_decay, initial_accumulator_value=initial_accumulator_value,
+                        eps=eps, pow=pow, use_sqrt=use_sqrt)
+        super().__init__(defaults=defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        lr_decay, alpha, eps = self.get_settings('lr_decay', 'alpha', 'eps', params=params, cls=NumberList)
+
+        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(self.settings[params[0]])
+
+        sq_sum = self.get_state('sq_sum', params=params, cls=TensorList)
+
+        # initialize accumulator on 1st step
+        if step == 1:
+            sq_sum.set_(tensors.full_like(self.get_settings('initial_accumulator_value', params=params)))
+
+        return adagrad_(
+            tensors,
+            sq_sum_=sq_sum,
+            alpha=alpha,
+            lr_decay=lr_decay,
+            eps=eps,
+            step=self.global_state["step"],
+            pow=pow,
+            use_sqrt=use_sqrt,
+
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+            vars=vars,
+        )
+
+
+class FullMatrixAdagrad(TensorwisePreconditioner):
+    def __init__(self, beta: float | None = None, decay: float | None = None, concat_params=False, update_freq=1, inner: Chainable | None = None):
+        defaults = dict(beta=beta, decay=decay)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, state, settings):
+        G = tensor.ravel()
+        GG = torch.outer(G, G)
+        decay = settings['decay']
+        beta = settings['beta']
+
+        if 'GG' not in state: state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+        if decay is not None: state['GG'].mul_(decay)
+
+        if beta is not None: state['GG'].lerp_(GG, 1-beta)
+        else: state['GG'].add_(GG)
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, state, settings):
+        GG = state['GG']
+
+        if tensor.numel() == 1:
+            return tensor / (GG**(1/2)).squeeze()
+
+        try:
+            B = matrix_power_eigh(GG, -1/2)
+        except torch.linalg.LinAlgError:
+            return tensor.div_(tensor.abs().max()) # conservative scaling
+
+        return (B @ tensor.ravel()).view_as(tensor)
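The adagrad_ helper above accumulates element-wise squared gradients and divides the update by the root of that accumulator, adding eps after the root; FullMatrixAdagrad does the analogous whitening with GG^(-1/2) built from gradient outer products. A rough standalone sketch of the diagonal rule on plain torch tensors follows (illustrative names and lr, not torchzero's TensorList API):

import torch

def adagrad_step(params, grads, sq_sums, lr=1e-2, lr_decay=0.0, eps=1e-10, step=1):
    # decayed step size, as in adagrad_ above
    clr = lr / (1 + step * lr_decay)
    for p, g, s in zip(params, grads, sq_sums):
        s.add_(g * g)                       # accumulate squared gradients
        p.sub_(clr * g / (s.sqrt() + eps))  # divide by root of accumulator, then step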
torchzero/modules/optimizers/adam.py
@@ -1,118 +1,112 @@
+from operator import itemgetter
+from functools import partial
+
+import torch
+
+from ...core import Module, Target, Transform
+from ...utils import NumberList, TensorList
+from ..functional import (
+    debias, debiased_step_size,
+    ema_,
+    sqrt_ema_sq_,
+)
+from ..lr.lr import lazy_lr
+from ..momentum.experimental import sqrt_nag_ema_sq_
+from ..momentum.momentum import nag_
+
+
+def adam_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    debiased: bool = True,
+    max_exp_avg_sq_: TensorList | None = None,
+    params_: TensorList | None = None,
+):
+    """Returns new tensors or updates params in-place."""
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0, lerp=True)
+
+    sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
+                                   debiased=False, step=step, pow=pow)
+
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+
+    # params is None, return update
+    if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
+
+    # update params in-place
+    params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
+    return None
+
+
+class Adam(Module):
+    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size. This implementation is slightly different from
+    pytorch in that debiasing is applied after adding epsilon.
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum. Defaults to 0.999.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        alpha (float, optional): learning rate. Defaults to 1.
+        amsgrad (bool, optional): whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+        pow (float, optional): power used in second momentum power and root. Defaults to 2.
+        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        eps: float = 1e-8,
+        amsgrad: bool = False,
+        alpha: float = 1.,
+        pow: float = 2,
+        debiased: bool = True,
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, eps=eps, alpha=alpha, amsgrad=amsgrad, pow=pow, debiased=debiased)
+        super().__init__(defaults)
+        self.getter = itemgetter('amsgrad', 'pow', 'debiased')
+
+    @torch.no_grad
+    def step(self, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1, beta2, eps, alpha = self.get_settings('beta1', 'beta2', 'eps', 'alpha', params=vars.params, cls=NumberList)
+        amsgrad, pow, debiased = self.getter(self.settings[vars.params[0]])
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=vars.params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=vars.params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
+        if vars.is_last:
+            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
+            passed_params = TensorList(vars.params)
+            vars.stop = True
+            vars.skip_update = True
+
+        else:
+            passed_params = None
+
+        vars.update = adam_(
+            tensors=TensorList(vars.get_update()),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            eps=eps,
+            step=step,
+            pow=pow,
+            debiased=debiased,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            params_=passed_params,
+        )
+
+        return vars
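As the docstring notes, this Adam folds bias correction into the step size and applies it after eps has been added to the root of the second moment, rather than debiasing the moments themselves as torch.optim.Adam does. A rough single-tensor sketch, assuming the usual sqrt(1 - beta2^t) / (1 - beta1^t) correction (the exact debiased_step_size helper may differ):

import math
import torch

def adam_step(p, g, exp_avg, exp_avg_sq, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    exp_avg.lerp_(g, 1 - beta1)                              # first-moment EMA
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # second-moment EMA
    alpha = lr * math.sqrt(1 - beta2**step) / (1 - beta1**step)  # debiased step size
    p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-alpha)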
torchzero/modules/optimizers/lion.py
@@ -1,15 +1,21 @@
 import torch
 
+from ...core import Module, Target, Transform
+from ...utils import NumberList, TensorList
 
 
+def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
+    """
+    Lion update rule.
+
+    Returns new tensors.
+    """
+    update = exp_avg_.lerp(tensors, 1-beta1).sign_()
+    exp_avg_.lerp_(tensors, 1-beta2)
     return update
 
+
+class Lion(Transform):
     """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.
 
     Args:
@@ -19,10 +25,11 @@ class Lion(OptimizerModule):
 
     def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
         defaults = dict(beta1=beta1, beta2=beta2)
-        super().__init__(defaults)
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        beta1, beta2 = self.get_settings('beta1', 'beta2', params=params, cls=NumberList)
+        exp_avg = self.get_state('ema', params=params, cls=TensorList)
+        return lion_(TensorList(tensors), exp_avg, beta1, beta2)
+
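lion_ above takes the sign of the momentum interpolated toward the incoming update with beta1, then refreshes the stored momentum with beta2. A rough single-tensor sketch (lr is illustrative; the module itself only returns the sign update and leaves the step size to later modules):

import torch

def lion_step(p, g, exp_avg, lr=1e-4, beta1=0.9, beta2=0.99):
    update = exp_avg.lerp(g, 1 - beta1).sign_()  # sign of interpolated momentum
    exp_avg.lerp_(g, 1 - beta2)                  # momentum EMA with beta2
    p.sub_(lr * update)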