torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
docs/source/conf.py
CHANGED
|
@@ -6,10 +6,10 @@
|
|
|
6
6
|
# -- Project information -----------------------------------------------------
|
|
7
7
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
|
8
8
|
import sys, os
|
|
9
|
-
#sys.path.insert(0, os.path.abspath('.../src'))
|
|
9
|
+
#sys.path.insert(0, os.path.abspath('.../src'))
|
|
10
10
|
|
|
11
11
|
project = 'torchzero'
|
|
12
|
-
copyright = '
|
|
12
|
+
copyright = '2025, Ivan Nikishev'
|
|
13
13
|
author = 'Ivan Nikishev'
|
|
14
14
|
|
|
15
15
|
# -- General configuration ---------------------------------------------------
|
|
@@ -24,10 +24,12 @@ extensions = [
|
|
|
24
24
|
'sphinx.ext.githubpages',
|
|
25
25
|
'sphinx.ext.napoleon',
|
|
26
26
|
'autoapi.extension',
|
|
27
|
+
"myst_nb",
|
|
28
|
+
|
|
27
29
|
# 'sphinx_rtd_theme',
|
|
28
30
|
]
|
|
29
31
|
autosummary_generate = True
|
|
30
|
-
autoapi_dirs = ['../../
|
|
32
|
+
autoapi_dirs = ['../../torchzero']
|
|
31
33
|
autoapi_type = "python"
|
|
32
34
|
# autoapi_ignore = ["*/tensorlist.py"]
|
|
33
35
|
|
|
@@ -48,7 +50,7 @@ exclude_patterns = []
|
|
|
48
50
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
|
49
51
|
|
|
50
52
|
#html_theme = 'alabaster'
|
|
51
|
-
html_theme = '
|
|
53
|
+
html_theme = 'sphinx_rtd_theme'
|
|
52
54
|
html_static_path = ['_static']
|
|
53
55
|
|
|
54
56
|
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
class MyModule:
|
|
2
|
+
"""[One-line summary of the class].
|
|
3
|
+
|
|
4
|
+
[A more detailed description of the class, explaining its purpose, how it
|
|
5
|
+
works, and its typical use cases. You can use multiple paragraphs.]
|
|
6
|
+
|
|
7
|
+
.. note::
|
|
8
|
+
[Optional: Add important notes, warnings, or usage guidelines here.
|
|
9
|
+
For example, you could mention if a closure is required, discuss
|
|
10
|
+
stability, or highlight performance characteristics. Use the `.. note::`
|
|
11
|
+
directive to make it stand out in the documentation.]
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
param1 (type, optional):
|
|
15
|
+
[Description of the first parameter. Use :code:`backticks` for
|
|
16
|
+
inline code like variable names or specific values like ``"autograd"``.
|
|
17
|
+
Explain what the parameter does.] Defaults to [value].
|
|
18
|
+
param2 (type):
|
|
19
|
+
[Description of a mandatory parameter (no "optional" or "Defaults to").]
|
|
20
|
+
**kwargs:
|
|
21
|
+
[If you accept keyword arguments, describe what they are used for.]
|
|
22
|
+
|
|
23
|
+
Examples:
|
|
24
|
+
[A title or short sentence describing the first example]:
|
|
25
|
+
|
|
26
|
+
.. code-block:: python
|
|
27
|
+
|
|
28
|
+
opt = tz.Modular(
|
|
29
|
+
model.parameters(),
|
|
30
|
+
...
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
[A title or short sentence for a second, different example]:
|
|
34
|
+
|
|
35
|
+
.. code-block:: python
|
|
36
|
+
|
|
37
|
+
opt = tz.Modular(
|
|
38
|
+
model.parameters(),
|
|
39
|
+
...
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
References:
|
|
43
|
+
- [Optional: A citation for a relevant paper, book, or algorithm.]
|
|
44
|
+
- [Optional: A link to a blog post or website with more information.]
|
|
45
|
+
|
|
46
|
+
"""
|
tests/test_identical.py
CHANGED
|
@@ -96,8 +96,7 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
|
|
|
96
96
|
|
|
97
97
|
@pytest.mark.parametrize('amsgrad', [True, False])
|
|
98
98
|
def test_adam(amsgrad):
|
|
99
|
-
|
|
100
|
-
# pytorch applies debiasing separately so it is applied before epsilo
|
|
99
|
+
torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
|
|
101
100
|
tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
|
|
102
101
|
tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
|
|
103
102
|
tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
|
|
@@ -133,7 +132,7 @@ def test_adam(amsgrad):
|
|
|
133
132
|
tz.m.Debias2(beta=0.999),
|
|
134
133
|
tz.m.Add(1e-8)]
|
|
135
134
|
))
|
|
136
|
-
tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
|
|
135
|
+
tz_fns = (torch_fn, tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
|
|
137
136
|
|
|
138
137
|
_assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
|
|
139
138
|
for fn in tz_fns:
|
tests/test_opts.py
CHANGED
|
@@ -1,4 +1,9 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
Sanity tests to make sure everything works.
|
|
3
|
+
|
|
4
|
+
This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
|
|
5
|
+
don't error or become unhinged with different parameter shapes.
|
|
6
|
+
"""
|
|
2
7
|
from collections.abc import Callable
|
|
3
8
|
from functools import partial
|
|
4
9
|
|
|
@@ -68,6 +73,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
|
|
|
68
73
|
assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
|
|
69
74
|
losses.append(loss)
|
|
70
75
|
|
|
76
|
+
losses.append(objective())
|
|
71
77
|
return torch.stack(losses).nan_to_num(0,10000,10000).min()
|
|
72
78
|
|
|
73
79
|
def _run_func(opt_fn: Callable, func:str, merge: bool, use_closure: bool, steps: int):
|
|
@@ -286,42 +292,42 @@ FDM_central2 = Run(
|
|
|
286
292
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
|
|
287
293
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
|
|
288
294
|
needs_closure=True,
|
|
289
|
-
func='booth', steps=50, loss=1e-
|
|
295
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
290
296
|
sphere_steps=2, sphere_loss=340,
|
|
291
297
|
)
|
|
292
298
|
FDM_forward2 = Run(
|
|
293
299
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
|
|
294
300
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
|
|
295
301
|
needs_closure=True,
|
|
296
|
-
func='booth', steps=50, loss=1e-
|
|
302
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
297
303
|
sphere_steps=2, sphere_loss=340,
|
|
298
304
|
)
|
|
299
305
|
FDM_backward2 = Run(
|
|
300
306
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
|
|
301
307
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
|
|
302
308
|
needs_closure=True,
|
|
303
|
-
func='booth', steps=50, loss=
|
|
309
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
304
310
|
sphere_steps=2, sphere_loss=340,
|
|
305
311
|
)
|
|
306
312
|
FDM_forward3 = Run(
|
|
307
313
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
|
|
308
314
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
|
|
309
315
|
needs_closure=True,
|
|
310
|
-
func='booth', steps=50, loss=
|
|
316
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
311
317
|
sphere_steps=2, sphere_loss=340,
|
|
312
318
|
)
|
|
313
319
|
FDM_backward3 = Run(
|
|
314
320
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
|
|
315
321
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
|
|
316
322
|
needs_closure=True,
|
|
317
|
-
func='booth', steps=50, loss=
|
|
323
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
318
324
|
sphere_steps=2, sphere_loss=340,
|
|
319
325
|
)
|
|
320
326
|
FDM_central4 = Run(
|
|
321
327
|
func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
|
|
322
328
|
sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
|
|
323
329
|
needs_closure=True,
|
|
324
|
-
func='booth', steps=50, loss=
|
|
330
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
325
331
|
sphere_steps=2, sphere_loss=340,
|
|
326
332
|
)
|
|
327
333
|
|
|
@@ -460,8 +466,8 @@ AdaptiveBacktracking = Run(
|
|
|
460
466
|
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
|
|
461
467
|
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
|
|
462
468
|
needs_closure=True,
|
|
463
|
-
func='booth', steps=50, loss=
|
|
464
|
-
sphere_steps=2, sphere_loss=
|
|
469
|
+
func='booth', steps=50, loss=1e-12, merge_invariant=True,
|
|
470
|
+
sphere_steps=2, sphere_loss=1e-10,
|
|
465
471
|
)
|
|
466
472
|
AdaptiveBacktracking_try_negative = Run(
|
|
467
473
|
func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
|
|
@@ -488,15 +494,6 @@ StrongWolfe = Run(
|
|
|
488
494
|
sphere_steps=2, sphere_loss=0,
|
|
489
495
|
)
|
|
490
496
|
|
|
491
|
-
# ------------------------- line_search/trust_region ------------------------- #
|
|
492
|
-
TrustRegion = Run(
|
|
493
|
-
func_opt=lambda p: tz.Modular(p, tz.m.TrustRegion()),
|
|
494
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.TrustRegion(init=0.1)),
|
|
495
|
-
needs_closure=True,
|
|
496
|
-
func='booth', steps=50, loss=0.1, merge_invariant=True,
|
|
497
|
-
sphere_steps=10, sphere_loss=1e-5,
|
|
498
|
-
)
|
|
499
|
-
|
|
500
497
|
# ----------------------------------- lr/lr ---------------------------------- #
|
|
501
498
|
LR = Run(
|
|
502
499
|
func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
|
|
@@ -524,7 +521,7 @@ PolyakStepSize = Run(
|
|
|
524
521
|
func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
|
|
525
522
|
sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
|
|
526
523
|
needs_closure=True,
|
|
527
|
-
func='booth', steps=50, loss=1e-
|
|
524
|
+
func='booth', steps=50, loss=1e-7, merge_invariant=True,
|
|
528
525
|
sphere_steps=10, sphere_loss=0.002,
|
|
529
526
|
)
|
|
530
527
|
RandomStepSize = Run(
|
|
@@ -581,8 +578,8 @@ UpdateGradientSignConsistency = Run(
|
|
|
581
578
|
sphere_steps=10, sphere_loss=2,
|
|
582
579
|
)
|
|
583
580
|
IntermoduleCautious = Run(
|
|
584
|
-
func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.01)),
|
|
585
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.1)),
|
|
581
|
+
func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.01)),
|
|
582
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.1)),
|
|
586
583
|
needs_closure=False,
|
|
587
584
|
func='booth', steps=50, loss=1e-4, merge_invariant=True,
|
|
588
585
|
sphere_steps=10, sphere_loss=0.1,
|
|
@@ -595,8 +592,8 @@ ScaleByGradCosineSimilarity = Run(
|
|
|
595
592
|
sphere_steps=10, sphere_loss=0.1,
|
|
596
593
|
)
|
|
597
594
|
ScaleModulesByCosineSimilarity = Run(
|
|
598
|
-
func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.05)),
|
|
599
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.1)),
|
|
595
|
+
func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.05)),
|
|
596
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.1)),
|
|
600
597
|
needs_closure=False,
|
|
601
598
|
func='booth', steps=50, loss=0.005, merge_invariant=True,
|
|
602
599
|
sphere_steps=10, sphere_loss=0.1,
|
|
@@ -604,44 +601,44 @@ ScaleModulesByCosineSimilarity = Run(
|
|
|
604
601
|
|
|
605
602
|
# ------------------------- momentum/matrix_momentum ------------------------- #
|
|
606
603
|
MatrixMomentum_forward = Run(
|
|
607
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
|
|
608
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
|
|
604
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
|
|
605
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
|
|
609
606
|
needs_closure=True,
|
|
610
607
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
611
608
|
sphere_steps=10, sphere_loss=0,
|
|
612
609
|
)
|
|
613
610
|
MatrixMomentum_forward = Run(
|
|
614
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
|
|
615
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
|
|
611
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
|
|
612
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
|
|
616
613
|
needs_closure=True,
|
|
617
614
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
618
615
|
sphere_steps=10, sphere_loss=0,
|
|
619
616
|
)
|
|
620
617
|
MatrixMomentum_forward = Run(
|
|
621
|
-
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
|
|
622
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
|
|
618
|
+
func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
|
|
619
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
|
|
623
620
|
needs_closure=True,
|
|
624
621
|
func='booth', steps=50, loss=0.05, merge_invariant=True,
|
|
625
622
|
sphere_steps=10, sphere_loss=0,
|
|
626
623
|
)
|
|
627
624
|
|
|
628
625
|
AdaptiveMatrixMomentum_forward = Run(
|
|
629
|
-
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
|
|
630
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
|
|
626
|
+
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
|
|
627
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
|
|
631
628
|
needs_closure=True,
|
|
632
629
|
func='booth', steps=50, loss=0.002, merge_invariant=True,
|
|
633
630
|
sphere_steps=10, sphere_loss=0,
|
|
634
631
|
)
|
|
635
632
|
AdaptiveMatrixMomentum_central = Run(
|
|
636
|
-
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
|
|
637
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
|
|
633
|
+
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
|
|
634
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
|
|
638
635
|
needs_closure=True,
|
|
639
636
|
func='booth', steps=50, loss=0.002, merge_invariant=True,
|
|
640
637
|
sphere_steps=10, sphere_loss=0,
|
|
641
638
|
)
|
|
642
639
|
AdaptiveMatrixMomentum_autograd = Run(
|
|
643
|
-
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
|
|
644
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
|
|
640
|
+
func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
|
|
641
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
|
|
645
642
|
needs_closure=True,
|
|
646
643
|
func='booth', steps=50, loss=0.002, merge_invariant=True,
|
|
647
644
|
sphere_steps=10, sphere_loss=0,
|
|
@@ -678,8 +675,8 @@ GradAccumulation = Run(
|
|
|
678
675
|
sphere_steps=20, sphere_loss=1e-11,
|
|
679
676
|
)
|
|
680
677
|
NegateOnLossIncrease = Run(
|
|
681
|
-
func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(),),
|
|
682
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(),),
|
|
678
|
+
func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
|
|
679
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
|
|
683
680
|
needs_closure=True,
|
|
684
681
|
func='booth', steps=50, loss=0.1, merge_invariant=True,
|
|
685
682
|
sphere_steps=20, sphere_loss=0.001,
|
|
@@ -687,7 +684,7 @@ NegateOnLossIncrease = Run(
|
|
|
687
684
|
# -------------------------------- misc/switch ------------------------------- #
|
|
688
685
|
Alternate = Run(
|
|
689
686
|
func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
|
|
690
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
|
|
687
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
|
|
691
688
|
needs_closure=False,
|
|
692
689
|
func='booth', steps=50, loss=1, merge_invariant=True,
|
|
693
690
|
sphere_steps=20, sphere_loss=20,
|
|
@@ -719,33 +716,33 @@ Lion = Run(
|
|
|
719
716
|
)
|
|
720
717
|
# ---------------------------- optimizers/shampoo ---------------------------- #
|
|
721
718
|
Shampoo = Run(
|
|
722
|
-
func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(
|
|
723
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.
|
|
719
|
+
func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
|
|
720
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
|
|
724
721
|
needs_closure=False,
|
|
725
|
-
func='booth', steps=50, loss=
|
|
726
|
-
sphere_steps=20, sphere_loss=
|
|
722
|
+
func='booth', steps=50, loss=0.02, merge_invariant=False,
|
|
723
|
+
sphere_steps=20, sphere_loss=1, # merge and unmerge lrs are very different so need to test convergence separately somewhere
|
|
727
724
|
)
|
|
728
725
|
|
|
729
726
|
# ------------------------- quasi_newton/quasi_newton ------------------------ #
|
|
730
727
|
BFGS = Run(
|
|
731
|
-
func_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
|
|
732
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
|
|
728
|
+
func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
|
|
729
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
|
|
733
730
|
needs_closure=True,
|
|
734
|
-
func='rosen', steps=50, loss=
|
|
735
|
-
sphere_steps=10, sphere_loss=
|
|
731
|
+
func='rosen', steps=50, loss=1e-10, merge_invariant=True,
|
|
732
|
+
sphere_steps=10, sphere_loss=1e-10,
|
|
736
733
|
)
|
|
737
734
|
SR1 = Run(
|
|
738
|
-
func_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
|
|
739
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
|
|
735
|
+
func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
|
|
736
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
|
|
740
737
|
needs_closure=True,
|
|
741
738
|
func='rosen', steps=50, loss=1e-12, merge_invariant=True,
|
|
742
739
|
sphere_steps=10, sphere_loss=0,
|
|
743
740
|
)
|
|
744
741
|
SSVM = Run(
|
|
745
|
-
func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
|
|
746
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
|
|
742
|
+
func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
|
|
743
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
|
|
747
744
|
needs_closure=True,
|
|
748
|
-
func='rosen', steps=50, loss=
|
|
745
|
+
func='rosen', steps=50, loss=0.5, merge_invariant=True,
|
|
749
746
|
sphere_steps=10, sphere_loss=0,
|
|
750
747
|
)
|
|
751
748
|
|
|
@@ -760,21 +757,21 @@ LBFGS = Run(
|
|
|
760
757
|
|
|
761
758
|
# ----------------------------- quasi_newton/lsr1 ---------------------------- #
|
|
762
759
|
LSR1 = Run(
|
|
763
|
-
func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
|
|
764
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
|
|
760
|
+
func_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
|
|
761
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
|
|
765
762
|
needs_closure=True,
|
|
766
763
|
func='rosen', steps=50, loss=0, merge_invariant=True,
|
|
767
764
|
sphere_steps=10, sphere_loss=0,
|
|
768
765
|
)
|
|
769
766
|
|
|
770
|
-
# ---------------------------- quasi_newton/olbfgs --------------------------- #
|
|
771
|
-
OnlineLBFGS = Run(
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
)
|
|
767
|
+
# # ---------------------------- quasi_newton/olbfgs --------------------------- #
|
|
768
|
+
# OnlineLBFGS = Run(
|
|
769
|
+
# func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
|
|
770
|
+
# sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
|
|
771
|
+
# needs_closure=True,
|
|
772
|
+
# func='rosen', steps=50, loss=0, merge_invariant=True,
|
|
773
|
+
# sphere_steps=10, sphere_loss=0,
|
|
774
|
+
# )
|
|
778
775
|
|
|
779
776
|
# ---------------------------- second_order/newton --------------------------- #
|
|
780
777
|
Newton = Run(
|
|
@@ -791,13 +788,13 @@ NewtonCG = Run(
|
|
|
791
788
|
sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
|
|
792
789
|
needs_closure=True,
|
|
793
790
|
func='rosen', steps=20, loss=1e-7, merge_invariant=True,
|
|
794
|
-
sphere_steps=2, sphere_loss=
|
|
791
|
+
sphere_steps=2, sphere_loss=3e-4,
|
|
795
792
|
)
|
|
796
793
|
|
|
797
794
|
# ---------------------------- smoothing/gaussian ---------------------------- #
|
|
798
795
|
GaussianHomotopy = Run(
|
|
799
|
-
func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
|
|
800
|
-
sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
|
|
796
|
+
func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
|
|
797
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
|
|
801
798
|
needs_closure=True,
|
|
802
799
|
func='booth', steps=20, loss=0.1, merge_invariant=True,
|
|
803
800
|
sphere_steps=10, sphere_loss=200,
|
|
@@ -854,8 +851,35 @@ SophiaH = Run(
|
|
|
854
851
|
sphere_steps=10, sphere_loss=40,
|
|
855
852
|
)
|
|
856
853
|
|
|
854
|
+
# -------------------------- higher_order ------------------------- #
|
|
855
|
+
HigherOrderNewton = Run(
|
|
856
|
+
func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
|
|
857
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
|
|
858
|
+
needs_closure=True,
|
|
859
|
+
func='rosen', steps=1, loss=2e-10, merge_invariant=True,
|
|
860
|
+
sphere_steps=1, sphere_loss=1e-10,
|
|
861
|
+
)
|
|
862
|
+
|
|
863
|
+
# ---------------------------- optimizers/ladagrad --------------------------- #
|
|
864
|
+
LMAdagrad = Run(
|
|
865
|
+
func_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(4)),
|
|
866
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(5)),
|
|
867
|
+
needs_closure=False,
|
|
868
|
+
func='booth', steps=50, loss=1e-6, merge_invariant=True,
|
|
869
|
+
sphere_steps=20, sphere_loss=1e-9,
|
|
870
|
+
)
|
|
871
|
+
|
|
872
|
+
# ------------------------------ optimizers/adan ----------------------------- #
|
|
873
|
+
Adan = Run(
|
|
874
|
+
func_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(1)),
|
|
875
|
+
sphere_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(0.1)),
|
|
876
|
+
needs_closure=False,
|
|
877
|
+
func='booth', steps=50, loss=60, merge_invariant=True,
|
|
878
|
+
sphere_steps=20, sphere_loss=60,
|
|
879
|
+
)
|
|
880
|
+
|
|
857
881
|
# ------------------------------------ CGs ----------------------------------- #
|
|
858
|
-
for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY):
|
|
882
|
+
for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
|
|
859
883
|
for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
|
|
860
884
|
# but also test 10 to make sure it doesn't explode after converging
|
|
861
885
|
Run(
|
|
@@ -868,10 +892,33 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
|
|
|
868
892
|
|
|
869
893
|
# ------------------------------- QN stability ------------------------------- #
|
|
870
894
|
# stability test
|
|
871
|
-
for QN in (
|
|
895
|
+
for QN in (
|
|
896
|
+
tz.m.BFGS,
|
|
897
|
+
partial(tz.m.BFGS, inverse=False),
|
|
898
|
+
tz.m.SR1,
|
|
899
|
+
partial(tz.m.SR1, inverse=False),
|
|
900
|
+
tz.m.DFP,
|
|
901
|
+
partial(tz.m.DFP, inverse=False),
|
|
902
|
+
tz.m.BroydenGood,
|
|
903
|
+
partial(tz.m.BroydenGood, inverse=False),
|
|
904
|
+
tz.m.BroydenBad,
|
|
905
|
+
partial(tz.m.BroydenBad, inverse=False),
|
|
906
|
+
tz.m.Greenstadt1,
|
|
907
|
+
tz.m.Greenstadt2,
|
|
908
|
+
tz.m.ICUM,
|
|
909
|
+
tz.m.ThomasOptimalMethod,
|
|
910
|
+
tz.m.FletcherVMM,
|
|
911
|
+
tz.m.Horisho,
|
|
912
|
+
partial(tz.m.Horisho, inner=tz.m.GradientCorrection()),
|
|
913
|
+
tz.m.Pearson,
|
|
914
|
+
tz.m.ProjectedNewtonRaphson,
|
|
915
|
+
tz.m.PSB,
|
|
916
|
+
tz.m.McCormick,
|
|
917
|
+
tz.m.SSVM,
|
|
918
|
+
):
|
|
872
919
|
Run(
|
|
873
|
-
func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
|
|
874
|
-
sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
|
|
920
|
+
func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
|
|
921
|
+
sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
|
|
875
922
|
needs_closure=True,
|
|
876
923
|
func='lstsq', steps=50, loss=1e-10, merge_invariant=False,
|
|
877
924
|
sphere_steps=10, sphere_loss=1e-20,
|
tests/test_tensorlist.py
CHANGED
|
@@ -1261,8 +1261,8 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
|
|
|
1261
1261
|
elif reduction_method == 'quantile': expected = vec.quantile(q)
|
|
1262
1262
|
else:
|
|
1263
1263
|
pytest.fail("Unknown global reduction")
|
|
1264
|
-
assert False,
|
|
1265
|
-
assert torch.allclose(result, expected)
|
|
1264
|
+
assert False, reduction_method
|
|
1265
|
+
assert torch.allclose(result, expected, atol=1e-4)
|
|
1266
1266
|
else:
|
|
1267
1267
|
expected_list = []
|
|
1268
1268
|
for t in simple_tl:
|