torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
--- a/torchzero/optim/experimental/experimental.py
+++ /dev/null
@@ -1,343 +0,0 @@
-from typing import Literal
-
-from ...modules import (
-    LR,
-    SGD,
-    Abs,
-    Adam,
-    Add,
-    AddMagnitude,
-    Cautious,
-    Div,
-    Divide,
-    Grad,
-    HeavyBall,
-    Interpolate,
-    Lerp,
-    Multistep,
-    NanToNum,
-    NesterovMomentum,
-    Normalize,
-    Random,
-    RDiv,
-    Reciprocal,
-    UseGradSign,
-    WeightDecay,
-)
-from ...modules import RandomCoordinateMomentum as _RandomCoordinateMomentum
-from ...modules.experimental import GradMin as _GradMin
-from ...modules.experimental import (
-    HVPDiagNewton as _HVPDiagNewton,
-)
-from ...modules.experimental import MinibatchRprop as _MinibatchRprop
-from ...modules.experimental import ReduceOutwardLR
-from ...random import Distributions
-from ..modular import Modular
-
-
-class HVPDiagNewton(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - this should approximate newton method with 2 backward passes, but only if hessian is purely diagonal"""
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-1,
-        eps: float = 1e-2,
-    ):
-        modules = [_HVPDiagNewton(eps = eps), LR(lr)]
-        super().__init__(params, modules)
-
-
-class ReciprocalSGD(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - this basically uses normalized *1 / (gradient + eps)*."""
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-2,
-        eps: float = 1e-2,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            AddMagnitude(eps, add_to_zero=False),
-            Reciprocal(),
-            NanToNum(0,0,0),
-            Normalize(1),
-            SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-
-        super().__init__(params, modules)
-
-class NoiseSign(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - uses random vector with gradient sign, and works quite well despite being completely random."""
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-2,
-        distribution: Distributions = 'normal',
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            Random(1, distribution),
-            UseGradSign(),
-            SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(2, WeightDecay(weight_decay))
-
-        super().__init__(params, modules)
-
-class MomentumNumerator(Modular):
-    """for experiments, unlikely to work well on most problems. (somewhat promising)
-
-    explanation - momentum divided by gradient."""
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-2,
-        momentum: float = 0.9,
-        nesterov: bool = True,
-        eps: float = 1e-2,
-        weight_decay: float = 0,
-        decoupled=True, ):
-
-        modules: list = [
-            Divide(
-                numerator = SGD(momentum = momentum, nesterov=nesterov),
-                denominator=[Abs(), Add(eps)]
-            ),
-            Normalize(),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class MomentumDenominator(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - gradient divided by normalized momentum."""
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-2,
-        momentum: float = 0.9,
-        nesterov: bool = True,
-        eps: float = 1e-2,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            Div([SGD(momentum=momentum, nesterov=nesterov), Abs(), Add(eps), Normalize(1)]),
-            Normalize(),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-class ExaggeratedNesterov(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - exaggerates difference between heavyball and nesterov momentum."""
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-2,
-        momentum: float = 0.9,
-        dampening: float = 0,
-        strength: float = 5,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-
-        modules: list = [
-            Interpolate(HeavyBall(momentum, dampening), NesterovMomentum(momentum, dampening), strength),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class ExtraCautiousAdam(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - caution with true backtracking."""
-    def __init__(
-        self,
-        params,
-        lr: float = 1,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad=False,
-        normalize = False,
-        c_eps = 1e-6,
-        mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
-        strength = 5,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            Adam(beta1, beta2, eps, amsgrad=amsgrad),
-            Lerp(Cautious(normalize, c_eps, mode), strength),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class InwardSGD(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - reduces lrs for updates that move weights away from 0."""
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        mul = 0.5,
-        use_grad=False,
-        invert=False,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-            LR(lr),
-            ReduceOutwardLR(mul, use_grad, invert),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class MultistepSGD(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - perform multiple steps per batch. Momentum applies to the total update over multiple step"""
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        num_steps=2,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        # lr, lr_module = _get_baked_in_and_module_lr(lr, kwargs) # multistep must use lr
-
-        modules: list = [
-            Multistep(LR(lr), num_steps=num_steps),
-            SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-class MinibatchRprop(Modular):
-    """
-    for experiments, unlikely to work well on most problems.
-
-    explanation: does 2 steps per batch, applies rprop rule on the second step.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1,
-        nplus: float = 1.2,
-        nminus: float = 0.5,
-        lb: float | None = 1e-6,
-        ub: float | None = 50,
-        backtrack=True,
-        next_mode = 'continue',
-        increase_mul = 0.5,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            _MinibatchRprop(nplus=nplus,nminus=nminus,lb=lb,ub=ub,backtrack=backtrack,next_mode=next_mode,increase_mul=increase_mul),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-class RandomCoordinateMomentum(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    Only uses `p` random coordinates of the new update. Other coordinates remain from previous update.
-    This works but I don't know if it is any good.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        p (float, optional): probability to update velocity with a new weigh value. Defaults to 0.1.
-        nesterov (bool, optional): if False, update uses delayed momentum. Defaults to True.
-
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        p: float = 0.1,
-        nesterov: bool = True,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [_RandomCoordinateMomentum(p, nesterov), LR(lr)]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-class GradMin(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - this uses gradient wrt sum of gradients + loss."""
-
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-2,
-        loss_term: float = 1,
-        square: bool = False,
-        maximize_grad: bool = False,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list = [
-            _GradMin(loss_term, square, maximize_grad),
-            SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
--- a/torchzero/optim/experimental/ray_search.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from typing import Literal, Any
-
-import torch
-
-from ...core import OptimizerModule
-from ...modules import (SGD, LineSearches, NewtonFDM,
-                        get_line_search, LR, WrapClosure)
-from ...modules.experimental.subspace import Subspace, ProjNormalize, ProjAscentRay
-from ..modular import Modular
-
-
-class NewtonFDMRaySearch(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - like a fancy line search, instead of a line searches in a cone using FDM newton."""
-    def __init__(
-        self,
-        params,
-        lr = 1e-2,
-        momentum:float = 0,
-        weight_decay:float = 0,
-        dampening: float = 0,
-        nesterov:bool = False,
-        n_rays = 3,
-        eps = 1e-2,
-        ray_width: float = 1e-1,
-        line_search: LineSearches | None = 'brent'
-    ):
-        modules: list[Any] = [
-            SGD(momentum=momentum, weight_decay=weight_decay, dampening=dampening, nesterov=nesterov),
-            LR(lr),
-            Subspace(NewtonFDM(eps = eps), ProjNormalize(ProjAscentRay(ray_width, n = n_rays))),
-        ]
-        if lr != 1:
-            modules.append(LR(lr))
-
-        if line_search is not None:
-            modules.append(get_line_search(line_search))
-
-        super().__init__(params, modules)
-
-
-class LBFGSRaySearch(Modular):
-    """for experiments, unlikely to work well on most problems.
-
-    explanation - like a fancy line search, instead of a line searches in a cone using LBFGS."""
-    def __init__(
-        self,
-        params,
-        lr = 1,
-        momentum:float = 0,
-        weight_decay:float = 0,
-        dampening: float = 0,
-        nesterov:bool = False,
-        n_rays = 24,
-        ray_width: float = 1e-1,
-        max_iter: int = 20,
-        max_eval: int | None = None,
-        tolerance_grad: float = 1e-7,
-        tolerance_change: float = 1e-9,
-        history_size: int = 100,
-        line_search_fn: str | Literal['strong_wolfe'] | None = None,
-    ):
-        lbfgs = WrapClosure(
-            torch.optim.LBFGS,
-            lr=lr,
-            max_iter=max_iter,
-            max_eval=max_eval,
-            tolerance_grad=tolerance_grad,
-            tolerance_change=tolerance_change,
-            history_size=history_size,
-            line_search_fn=line_search_fn,
-        )
-        modules: list[OptimizerModule] = [
-            SGD(momentum=momentum, weight_decay=weight_decay, dampening=dampening, nesterov=nesterov),
-            Subspace(lbfgs, ProjNormalize(ProjAscentRay(ray_width, n = n_rays))),
-
-        ]
-
-        super().__init__(params, modules)
-
-
-
--- a/torchzero/optim/first_order/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from .cautious import CautiousAdamW, CautiousLion, CautiousSGD
-from .optimizers import (
-    GD,
-    SGD,
-    Adagrad,
-    Adam,
-    AdamW,
-    Grams,
-    LaplacianSmoothingSGD,
-    Lion,
-    NestedNesterov,
-    NoisySGD,
-    NormSGD,
-    RMSProp,
-    Rprop,
-    SignSGD,
-)
-from .forward_gradient import ForwardGradient
--- a/torchzero/optim/first_order/cautious.py
+++ /dev/null
@@ -1,158 +0,0 @@
-from typing import Literal
-
-
-from ...core import OptimizerModule
-from ...modules import Cautious, Adam, SGD, Lion, WeightDecay, LR
-from ..modular import Modular
-
-
-class CautiousAdamW(Modular):
-    """Adam, but updates for parameters where update and gradient sign is inconsistent are negated.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-        beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-        eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-        amsgrad (bool, optional):
-            whether to use the AMSGrad variant of this algorithm from
-            On the Convergence of Adam and Beyond (default: False).
-        normalize (bool, optional):
-            renormalize update after masking.
-            only has effect when mode is 'zero'. Defaults to False.
-        c_eps (float, optional): epsilon for normalization after applying cautioning mask. Defaults to 1e-6.
-        mode (str, optional):
-            what to do with updates with inconsistent signs.
-
-            "zero" - set them to zero (as in paper)
-
-            "grad" - set them to the gradient
-
-            "negate" - negate them (same as using update magnitude and gradient sign).
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad=False,
-        normalize = False,
-        c_eps = 1e-6,
-        mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list[OptimizerModule] = [
-            Adam(beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
-            LR(lr),
-            Cautious(normalize = normalize, eps = c_eps, mode = mode),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-        super().__init__(params, modules)
-
-
-class CautiousSGD(Modular):
-    """SGD with momentum, but updates for parameters where update and gradient sign is inconsistent are negated.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        normalize (bool, optional):
-            renormalize update after masking.
-            only has effect when mode is 'zero'. Defaults to False.
-        c_eps (float, optional): epsilon for normalization after applying cautioning mask. Defaults to 1e-6.
-        mode (str, optional):
-            what to do with updates with inconsistent signs.
-
-            "zero" - set them to zero (as in paper)
-
-            "grad" - set them to the gradient
-
-            "negate" - negate them (same as using update magnitude and gradient sign).
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        momentum: float = 0.9,
-        dampening: float = 0,
-        nesterov: bool = True,
-        c_eps = 1e-6,
-        normalize = False,
-        mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list[OptimizerModule] = [
-            SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-            LR(lr),
-            Cautious(normalize = normalize, eps = c_eps, mode = mode),
-        ]
-
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-
-        super().__init__(params, modules)
-
-
-class CautiousLion(Modular):
-    """Lion optimizer, but updates for parameters where update and gradient sign is inconsistent are negated.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float): learning rate (default: 1e-3).
-        beta1 (float, optional): dampening for momentum. Defaults to 0.9.
-        beta2 (float, optional): momentum factor. Defaults to 0.99.
-        normalize (bool, optional):
-            renormalize update after masking.
-            only has effect when mode is 'zero'. Defaults to False.
-        c_eps (float, optional): epsilon for normalization after applying cautioning mask. Defaults to 1e-6.
-        mode (str, optional):
-            what to do with updates with inconsistent signs.
-
-            "zero" - set them to zero (as in paper)
-
-            "grad" - set them to the gradient
-
-            "negate" - negate them (same as using update magnitude and gradient sign).
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-    """
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        beta1: float = 0.9,
-        beta2: float = 0.99,
-        c_eps = 1e-6,
-        normalize = False,
-        mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
-        weight_decay: float = 0,
-        decoupled=True,
-    ):
-        modules: list[OptimizerModule] = [
-            Lion(beta1, beta2),
-            LR(lr),
-            Cautious(normalize = normalize, eps = c_eps, mode = mode),
-        ]
-
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        else: modules.insert(0, WeightDecay(weight_decay))
-
-        super().__init__(params, modules)
--- a/torchzero/optim/first_order/forward_gradient.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from typing import Literal
-
-import torch
-
-from ...core import OptimizationVars, OptimizerModule
-from ...modules import ForwardGradient as _ForwardGradient, SGD, WeightDecay, LR
-from ...tensorlist import Distributions
-from ..modular import Modular
-
-
-class ForwardGradient(Modular):
-    """
-
-    Evaluates jacobian-vector product with a random vector using forward mode autodiff (torch.func.jvp), which is
-    the true directional derivative in the direction of that vector.
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional): learning rate. Defaults to 1e-3.
-        n_samples (int): number of forward gradients to evaluate and average.
-        distribution (Distributions): distribution for random tangent vector.
-        mode (str):
-            "jvp" - uses forward mode AD, usually slightly slower than backward mode AD but uses significantly less memory.
-
-            "grad" - evaluates gradient with `loss.backward()` which may be faster but uses all the memory, mainly useful for
-            benchmarking as there is probably no point in forward gradient if full gradient is available.
-
-            "fd" - uses finite difference to estimate JVP, doesn't require gradients to be known. Equivalent to randomized FDM.
-
-        fd_eps (float, optional): epsilon for finite difference, only has effect if mode is "fd". Defaults to 1e-4.
-        momentum (float, optional): momentum. Defaults to 0.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-        nesterov (bool, optional):
-            enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-        weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-        decoupled (bool, optional):
-            decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-    Reference:
-        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022).
-        Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
-        https://arxiv.org/abs/2202.08587
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float = 1e-3,
-        n_samples: int = 1,
-        distribution: Distributions = "normal",
-        mode: Literal["jvp", "grad", "fd"] = "jvp",
-        fd_eps: float = 1e-4,
-        momentum: float = 0,
-        dampening: float = 0,
-        nesterov: bool = False,
-        weight_decay: float = 0,
-        decoupled=False,
-    ):
-        modules: list = [
-            _ForwardGradient(
-                n_samples=n_samples,
-                distribution=distribution,
-                mode=mode,
-                fd_eps=fd_eps,
-            ),
-            SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-            LR(lr),
-        ]
-        if decoupled: modules.append(WeightDecay(weight_decay))
-        super().__init__(params, modules)
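Every class removed in the hunks above follows the same shape: build a list of torchzero modules, append or prepend WeightDecay depending on the decoupled flag, and hand the list to Modular. The following is a minimal sketch of that composition pattern, reconstructed only from the deleted 0.1.8 sources shown here (the import paths torchzero.optim.modular and torchzero.modules mirror the old file layout; whether they survive unchanged in 0.3.1 is not shown in this diff):

    import torch
    # Paths as used by the removed 0.1.8 wrappers; 0.3.1 reorganizes these modules.
    from torchzero.optim.modular import Modular
    from torchzero.modules import SGD, LR, Cautious, WeightDecay

    model = torch.nn.Linear(10, 1)

    # Roughly what the deleted CautiousSGD(params, lr=1e-3, momentum=0.9,
    # nesterov=True, weight_decay=1e-2, decoupled=True) assembled internally:
    modules = [
        SGD(momentum=0.9, dampening=0, weight_decay=0, nesterov=True),  # momentum update
        LR(1e-3),                                                       # apply learning rate
        Cautious(normalize=False, eps=1e-6, mode='zero'),               # zero sign-inconsistent entries
    ]
    modules.append(WeightDecay(1e-2))  # decoupled=True places decay after LR
    opt = Modular(model.parameters(), modules)

In the deleted wrappers, decoupled=True appends WeightDecay after LR so the decay is not scaled by the learning rate, while decoupled=False inserts it before the other modules; that placement is the entire difference between the two settings.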