torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/optim/first_order/optimizers.py
@@ -1,570 +0,0 @@
-from collections.abc import Iterable
-from typing import Literal
-
-from ...modules import (
-    LR,
-    AddNoise,
-    Centralize,
-    Grad,
-    HeavyBall,
-    LineSearches, LaplacianSmoothing,
-    NesterovMomentum,
-    Normalize,
-    Random,
-    Sign,
-    UseGradSign,
-    WeightDecay,
-    get_line_search,
-)
-from ...modules import SGD as _SGD
-from ...modules import Adagrad as _Adagrad
-from ...modules import Adam as _Adam
-from ...modules import Lion as _Lion
-from ...modules import RMSProp as _RMSProp
-from ...modules import Rprop as _Rprop
-from ...random.random import Distributions
-from ..modular import Modular
-
-
-class GD(Modular):
-    """Gradient descent with armijo line search.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1).
-        line_search (LineSearches | None, optional):
-            line search type. Defaults to 'armijo'.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1,
-        line_search: LineSearches | None = 'armijo',
-    ):
-        modules: list = [LR(lr)]
-        if line_search is not None: modules.append(get_line_search(line_search))
-
-        super().__init__(params, *modules)
-
-class SGD(Modular):
-    """Exactly matches `torch.optim.SGD`, except
-    nesterov momentum additionally supports dampening, negative momentum is allowed,
-    and weight decay supports decoupling.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=False,
-    ):
-        modules: list = [
-            _SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-            LR(lr)
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-class SignSGD(Modular):
-    """SGD that uses sign of the gradient, can act as a normalizer and improve stability.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=False,
-    ):
-        modules: list = [
-            Sign(),
-            _SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-class NormSGD(Modular):
-    """SGD with gradient normalization and optionally centralization.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float):
-            learning rate, gradients are normalized to this value.
-            This can typically be 10 times bigger than normal SGD (default: 1e-1).
-        centralize (bool, optional): whether to centralize gradients (default: True).
-        norm_mode (str, optional):
-            what to normalize.
-
-            - "global": normalize the entire gradient, as if it was a single vector.
-
-            - "param": normalize each param's gradient.
-
-            - "channel": normalize gradient of each channel of each param (default).
-        centralize_mode (str, optional): what to centralize (same options as `norm_mode`). Defaults to 'channel'.
-        min_numel (int, optional):
-            skips parameters with less than this many elements. This avoids the issue where
-            parameters that have a single element always get set to the value of 1.
-            Ignored when mode is 'global'. Defaults to 2.
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-1,
-        normalize=True,
-        norm_mode: Literal["global", "param", "channel"] = 'channel',
-        ord = 2,
-        centralize=True,
-        centralize_mode: Literal["global", "param", "channel"] = 'channel',
-        min_numel=2,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            _SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        if normalize: modules.insert(0, Normalize(1, mode=norm_mode, min_numel=min_numel, ord=ord))
-        if centralize: modules.insert(0, Centralize(centralize_mode, min_numel=min_numel))
-        super().__init__(params, modules)
-
-
-class NoisySGD(Modular):
-    """SGD with noise added to gradients. The formula for noise magnitude is `alpha * mean(abs(grad))`.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3)
-        alpha (float, optional): magnitude of noise. Defaults to 1e-2.
-        distribution (Distributions, optional): distribution of noise. Defaults to 'normal'.
-        mode (str, optional):
-            how to calculate noise magnitude.
-
-            - "absolute": ignores gradient magnitude and always uses `alpha` as magnitude.
-
-            - "global": multiplies `alpha` by mean of the entire gradient, as if it was a single vector.
-
-            - "param": multiplies `alpha` by mean of each individual parameter (default).
-
-            - "channel": multiplies `alpha` by mean of each channel of each parameter.
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        alpha: float = 1,
-        distribution: Distributions = 'normal',
-        mode: Literal["absolute", "global", "param", "channel"] = "param",
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=False,
-    ):
-
-        modules: list = [
-            _SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-            AddNoise(alpha, distribution, mode),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class LaplacianSmoothingSGD(Modular):
-    """SGD with laplacian smoothing.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3)
-        sigma (float, optional): controls the amount of smoothing. Defaults to 1.
-        layerwise (bool, optional):
-            If True, applies smoothing to each parameter's gradient separately,
-            Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
-        min_numel (int, optional):
-            minimum number of elements in a parameter to apply laplacian smoothing to.
-            Only has effect if `layerwise` is True. Defaults to 4.
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-    Reference:
-        *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
-        Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        sigma: float = 1,
-        layerwise: bool = True,
-        min_numel: int = 4,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=False,
-    ):
-
-        modules: list = [
-            LaplacianSmoothing(sigma=sigma, layerwise=layerwise,min_numel=min_numel),
-            _SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class Adagrad(Modular):
-    """Divides ascent direction by mean square root of the sum of all past ascent directions.
-
-    Exactly matches `torch.optim.Adagrad`.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        lr_decay (float, optional): learning rate decay. Defaults to 0.
-        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
-        eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-10.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        lr_decay: float = 0,
-        initial_accumulator_value: float = 0,
-        eps: float = 1e-10,
-        weight_decay: float = 0,
-        decoupled=False,
-    ):
-        modules: list = [
-            _Adagrad(lr_decay = lr_decay, initial_accumulator_value = initial_accumulator_value, eps = eps),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class Rprop(Modular):
-    """
-    Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
-    or `nminus` if it did. Then the update is applied with the sign of the current gradient.
-
-    Additionally, if gradient changes sign, the update for that weight is reverted.
-    Next step, magnitude for that weight won't change.
-
-    Compared to pytorch this also implements backtracking update when sign changes.
-    To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
-        nminus (float): multiplicative decrease factor for when ascent changed sign (default: 0.5).
-        lb (float): minimum step size, can be None (default: 1e-6)
-        ub (float): maximum step size, can be None (default: 50)
-        backtrack (float):
-            if True, when ascent sign changes, undoes last weight update, otherwise sets update to 0.
-            When this is False, this exactly matches pytorch Rprop. (default: True)
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-    reference
-        *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
-        The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        nplus: float = 1.2,
-        nminus: float = 0.5,
-        lb: float | None = 1e-6,
-        ub: float | None = 50,
-        backtrack=True,
-        weight_decay: float = 0,
-        decoupled=False,
-    ):
-        modules: list = [
-            _Rprop(nplus = nplus, nminus = nminus, lb=lb, ub = ub, backtrack=backtrack),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class RMSProp(Modular):
-    """
-    Divides ascent direction by running average of its mean square root.
-
-    Exactly matches `torch.optim.RMSProp`, except momentum initialization is arbitrarily different.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        momentum (float, optional): momentum. Defaults to 0.
-        alpha (float, optional):
-            smoothing constant (decay of ascent mean square root running average).
-            Defaults to 0.99.
-        eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-8.
-        centered (float, optional):
-            if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance.
-            Defaults to False.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-    reference
-        https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-2,
-        momentum: float = 0,
-        alpha: float = 0.99,
-        eps: float = 1e-8,
-        centered: bool = False,
-        nesterov = False,
-        dampening: float = 0,
-        weight_decay: float = 0,
-        decoupled=False,
-    ):
-        modules: list = [
-            _RMSProp(smoothing = alpha, eps = eps, centered = centered,),
-            _SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class Adam(Modular):
-    """Adam. Combines momentum and RMSProp. Exactly matches `torch.optim.Adam`, except
-    if `decoupled` is True, weight decay is truly decoupled and doesn't depend on LR.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-        beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-        eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-        amsgrad (bool, optional):
-            whether to use the AMSGrad variant of this algorithm from
-            On the Convergence of Adam and Beyond (default: False).
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad=False,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            _Adam(beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class AdamW(Adam):
-    """AdamW. Combines momentum and RMSProp. Exactly matches `torch.optim.Adam`, except
-    if `decoupled` is True, weight decay is truly decoupled and doesn't depend on LR.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-        beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-        eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-        amsgrad (bool, optional):
-            whether to use the AMSGrad variant of this algorithm from
-            On the Convergence of Adam and Beyond (default: False).
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.01.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad=False,
-        weight_decay: float = 1e-2,
-        decoupled=True,
-    ):
-        super().__init__(params=params,lr=lr,beta1=beta1,beta2=beta2,eps=eps,amsgrad=amsgrad,weight_decay=weight_decay,decoupled=decoupled)
-
-class Grams(Modular):
-    """Grams (Gradient Descent with Adaptive Momentum Scaling) from https://arxiv.org/abs/2412.17107v1.
-    This is Adam but uses gradient sign.
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-        beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-        eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-        amsgrad (bool, optional):
-            whether to use the AMSGrad variant of this algorithm from
-            On the Convergence of Adam and Beyond (default: False).
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad=False,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            _Adam(beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
-            LR(lr),
-            UseGradSign()
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-class Lion(Modular):
-    """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        beta1 (float, optional): dampening for momentum. Defaults to 0.9.
-        beta2 (float, optional): momentum factor. Defaults to 0.99.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        beta1: float = 0.9,
-        beta2: float = 0.99,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            _Lion(beta1, beta2),
-            LR(lr)
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-
-class NestedNesterov(Modular):
-    """Chains multiple nesterov momentums. The default (0.5, 0.5) seems to work well.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        momentums (Iterable[float], optional): sequence of momentums. Defaults to (0.5, 0.5, 0.5).
-        dampening (float | Iterable[float], optional):
-            sequence of dampenings for each momentum, or a single float that is used
-            for all momentums. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to True.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        momentums: Iterable[float] = (0.5, 0.5, 0.5),
-        dampening: float | Iterable[float] = 0,
-        nesterov=True,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        momentums = list(momentums)
-        if isinstance(dampening, (int, float)): dampening = [dampening for _ in momentums]
-
-        cls = NesterovMomentum if nesterov else HeavyBall
-        modules: list = [cls(m, d) for m, d in zip(momentums, dampening)] + [LR(lr)]
-
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-
-        super().__init__(params, modules)
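For context, the removed optimizers.py above built each preset by chaining torchzero modules and handing the chain to the (also removed) Modular base class. A minimal usage sketch, assuming the 0.1.8 layout shown in the removed file; the model, data, and training loop below are illustrative and not part of the package, and the step()/zero_grad() calls assume Modular followed the standard torch.optim.Optimizer interface:

import torch
from torch import nn

# Illustrative model and batch; any torch.nn.Module would be used the same way.
model = nn.Linear(10, 1)
inputs, targets = torch.randn(32, 10), torch.randn(32, 1)

# One of the removed 0.1.8 presets: SGD composed of _SGD(...), LR(lr) and,
# when decoupled=True, a trailing WeightDecay module (see the diff above).
from torchzero.optim.first_order.optimizers import SGD  # path removed in 0.3.1

opt = SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True,
          weight_decay=1e-4, decoupled=True)

for _ in range(10):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    opt.step()

In 0.3.1 these presets no longer exist; the file list above shows the corresponding building blocks reorganized under torchzero/modules/* (for example momentum/, lr/, weight_decay/ and smoothing/).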