torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_identical.py
CHANGED
@@ -97,30 +97,30 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
 @pytest.mark.parametrize('amsgrad', [True, False])
 def test_adam(amsgrad):
     torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
-    tz_fn = lambda p: tz.
-    tz_fn2 = lambda p: tz.
-    tz_fn3 = lambda p: tz.
-    tz_fn4 = lambda p: tz.
-    tz_fn5 = lambda p: tz.
-    tz_fn_ops = lambda p: tz.
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad))
+    tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
+    tz_fn3 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
+    tz_fn4 = lambda p: tz.Optimizer(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
+    tz_fn5 = lambda p: tz.Optimizer(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
+    tz_fn_ops = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
             tz.m.EMA(0.9, debiased=True),
             [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
     ))
-    tz_fn_ops2 = lambda p: tz.
+    tz_fn_ops2 = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
             [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Debias2(beta=0.999), tz.m.Add(1e-8)]
     ))
-    tz_fn_ops3 = lambda p: tz.
+    tz_fn_ops3 = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
             [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Add(1e-8)]
     ))
-    tz_fn_ops4 = lambda p: tz.
+    tz_fn_ops4 = lambda p: tz.Optimizer(
         p,
         tz.m.DivModules(
             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
@@ -145,19 +145,19 @@ def test_adam(amsgrad):
 @pytest.mark.parametrize('amsgrad', [True, False])
 @pytest.mark.parametrize('lr', [0.1, 1])
 def test_adam_hyperparams(beta1, beta2, eps, amsgrad, lr):
-    tz_fn = lambda p: tz.
-    tz_fn2 = lambda p: tz.
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
+    tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
     _assert_identical_opts([tz_fn, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)

 @pytest.mark.parametrize('centered', [True, False])
 def test_rmsprop(centered):
     torch_fn = lambda p: torch.optim.RMSprop(p, 1, centered=centered)
-    tz_fn = lambda p: tz.
-    tz_fn2 = lambda p: tz.
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.RMSprop(centered=centered, init='zeros'))
+    tz_fn2 = lambda p: tz.Optimizer(
         p,
         tz.m.Div([tz.m.CenteredSqrtEMASquared(0.99) if centered else tz.m.SqrtEMASquared(0.99), tz.m.Add(1e-8)]),
     )
-    tz_fn3 = lambda p: tz.
+    tz_fn3 = lambda p: tz.Optimizer(
         p,
         tz.m.Div([tz.m.CenteredEMASquared(0.99) if centered else tz.m.EMASquared(0.99), tz.m.Sqrt(), tz.m.Add(1e-8)]),
     )
@@ -173,7 +173,7 @@ def test_rmsprop(centered):
 @pytest.mark.parametrize('centered', [True, False])
 @pytest.mark.parametrize('lr', [0.1, 1])
 def test_rmsprop_hyperparams(beta, eps, centered, lr):
-    tz_fn = lambda p: tz.
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
     torch_fn = lambda p: torch.optim.RMSprop(p, lr, beta, eps=eps, centered=centered)
     _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=10)

@@ -185,7 +185,7 @@ def test_rmsprop_hyperparams(beta, eps, centered, lr):
 @pytest.mark.parametrize('ub', [50, 1.5])
 @pytest.mark.parametrize('lr', [0.1, 1])
 def test_rprop(nplus, nminus, lb, ub, lr):
-    tz_fn = lambda p: tz.
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
     torch_fn = lambda p: torch.optim.Rprop(p, lr, (nminus, nplus), (lb, ub))
     _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=30)
     _assert_identical_merge_closure(tz_fn, 'cpu', 30)
@@ -193,8 +193,8 @@ def test_rprop(nplus, nminus, lb, ub, lr):

 def test_adagrad():
     torch_fn = lambda p: torch.optim.Adagrad(p, 1)
-    tz_fn = lambda p: tz.
-    tz_fn2 = lambda p: tz.
+    tz_fn = lambda p: tz.Optimizer(p, tz.m.Adagrad(), tz.m.LR(1))
+    tz_fn2 = lambda p: tz.Optimizer(
         p,
         tz.m.Div([tz.m.Pow(2), tz.m.AccumulateSum(), tz.m.Sqrt(), tz.m.Add(1e-10)]),
     )
@@ -212,15 +212,15 @@ def test_adagrad():
 @pytest.mark.parametrize('lr', [0.1, 1])
 def test_adagrad_hyperparams(initial_accumulator_value, eps, lr):
     torch_fn = lambda p: torch.optim.Adagrad(p, lr, initial_accumulator_value=initial_accumulator_value, eps=eps)
-    tz_fn1 = lambda p: tz.
-    tz_fn2 = lambda p: tz.
+    tz_fn1 = lambda p: tz.Optimizer(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
+    tz_fn2 = lambda p: tz.Optimizer(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
     _assert_identical_opts([torch_fn, tz_fn1, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)


 @pytest.mark.parametrize('tensorwise', [True, False])
 def test_graft(tensorwise):
-    graft1 = lambda p: tz.
-    graft2 = lambda p: tz.
+    graft1 = lambda p: tz.Optimizer(p, tz.m.Graft(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
+    graft2 = lambda p: tz.Optimizer(p, tz.m.LBFGS(), tz.m.GraftInputToOutput([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
     _assert_identical_opts([graft1, graft2], merge=True, use_closure=True, device='cpu', steps=10)
     for fn in [graft1, graft2]:
         if tensorwise: _assert_identical_closure(fn, merge=True, device='cpu', steps=10)
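The recurring pattern in the new test code is composing an update rule from `tz.m` modules inside `tz.Optimizer` and checking it against a `torch.optim` reference. Below is a minimal sketch of that usage outside the test harness. It reuses only the `tz.Optimizer`, `tz.m.Adam`, and `tz.m.LR` names visible in this diff; the toy model, data, closure convention (a `backward` flag, as in torchzero's closure-based tests), and the standard `step`/`zero_grad` interface are assumptions for illustration, not part of this package diff.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
x, y = torch.randn(32, 4), torch.randn(32, 1)

# Update rule composed from modules: Adam statistics followed by a separate
# LR(0.1) step-size module, the same construction used by tz_fn in test_adam.
opt = tz.Optimizer(model.parameters(), tz.m.Adam(amsgrad=False), tz.m.LR(0.1))

def closure(backward=True):
    # Assumed closure convention: a `backward` flag lets modules re-evaluate
    # the loss without recomputing gradients.
    loss = torch.nn.functional.mse_loss(model(x), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```

Keeping the learning rate as its own `tz.m.LR` module, rather than baking it into `tz.m.Adam`, appears to be what the `# test LR fusing` comment in `test_adam` refers to: the tests assert that the fused and unfused forms produce identical steps.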