torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
docs/source/conf.py
CHANGED
@@ -6,10 +6,10 @@
 # -- Project information -----------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
 import sys, os
-#sys.path.insert(0, os.path.abspath('.../src'))
+#sys.path.insert(0, os.path.abspath('.../src'))
 
 project = 'torchzero'
-copyright = '
+copyright = '2025, Ivan Nikishev'
 author = 'Ivan Nikishev'
 
 # -- General configuration ---------------------------------------------------
@@ -24,10 +24,12 @@ extensions = [
     'sphinx.ext.githubpages',
     'sphinx.ext.napoleon',
     'autoapi.extension',
+    "myst_nb",
+
     # 'sphinx_rtd_theme',
 ]
 autosummary_generate = True
-autoapi_dirs = ['../../
+autoapi_dirs = ['../../torchzero']
 autoapi_type = "python"
 # autoapi_ignore = ["*/tensorlist.py"]
 
@@ -48,7 +50,7 @@ exclude_patterns = []
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
 
 #html_theme = 'alabaster'
-html_theme = '
+html_theme = 'sphinx_rtd_theme'
 html_static_path = ['_static']
 
 
docs/source/docstring template.py
ADDED
@@ -0,0 +1,46 @@
+class MyModule:
+    """[One-line summary of the class].
+
+    [A more detailed description of the class, explaining its purpose, how it
+    works, and its typical use cases. You can use multiple paragraphs.]
+
+    .. note::
+        [Optional: Add important notes, warnings, or usage guidelines here.
+        For example, you could mention if a closure is required, discuss
+        stability, or highlight performance characteristics. Use the `.. note::`
+        directive to make it stand out in the documentation.]
+
+    Args:
+        param1 (type, optional):
+            [Description of the first parameter. Use :code:`backticks` for
+            inline code like variable names or specific values like ``"autograd"``.
+            Explain what the parameter does.] Defaults to [value].
+        param2 (type):
+            [Description of a mandatory parameter (no "optional" or "Defaults to").]
+        **kwargs:
+            [If you accept keyword arguments, describe what they are used for.]
+
+    Examples:
+        [A title or short sentence describing the first example]:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                ...
+            )
+
+        [A title or short sentence for a second, different example]:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                ...
+            )
+
+    References:
+        - [Optional: A citation for a relevant paper, book, or algorithm.]
+        - [Optional: A link to a blog post or website with more information.]
+
+    """
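For orientation, a docstring written against this template would look roughly like the sketch below. The module name, parameters and values are invented for illustration and are not taken from the package:

# hypothetical module, used only to show how the template above is filled in
class NormalizeUpdate:
    """Rescales the update so that its global norm equals ``target_norm``.

    Normalizing the update removes the magnitude produced by preceding modules,
    which can make step sizes easier to tune.

    .. note::
        This module does not require a closure.

    Args:
        target_norm (float, optional):
            norm the update is rescaled to. Defaults to 1.
        eps (float, optional):
            small value to avoid division by zero. Defaults to 1e-8.

    Examples:
        Normalized heavy-ball:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.HeavyBall(0.9),
                tz.m.NormalizeUpdate(),
                tz.m.LR(0.1),
            )

    References:
        - [A reference for normalized gradient methods would go here.]
    """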
tests/test_identical.py
CHANGED
@@ -96,8 +96,7 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
 
 @pytest.mark.parametrize('amsgrad', [True, False])
 def test_adam(amsgrad):
-
-    # pytorch applies debiasing separately so it is applied before epsilo
+    torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
     tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
     tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
     tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
@@ -133,7 +132,7 @@ def test_adam(amsgrad):
         tz.m.Debias2(beta=0.999),
         tz.m.Add(1e-8)]
     ))
-    tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
+    tz_fns = (torch_fn, tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
 
     _assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
     for fn in tz_fns:
tests/test_opts.py
CHANGED
@@ -292,42 +292,42 @@ FDM_central2 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central2'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=1e-
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_forward2 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward2'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=1e-
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_backward2 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward2'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_forward3 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='forward3'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_backward3 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='backward3'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 FDM_central4 = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.FDM(formula='central4'), tz.m.LR(0.1)),
     needs_closure=True,
-    func='booth', steps=50, loss=
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
     sphere_steps=2, sphere_loss=340,
 )
 
@@ -466,8 +466,8 @@ AdaptiveBacktracking = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
     sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveBacktracking()),
     needs_closure=True,
-    func='booth', steps=50, loss=
-    sphere_steps=2, sphere_loss=
+    func='booth', steps=50, loss=1e-12, merge_invariant=True,
+    sphere_steps=2, sphere_loss=1e-10,
 )
 AdaptiveBacktracking_try_negative = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.Mul(-1), tz.m.AdaptiveBacktracking(try_negative=True)),
@@ -494,15 +494,6 @@ StrongWolfe = Run(
     sphere_steps=2, sphere_loss=0,
 )
 
-# ------------------------- line_search/trust_region ------------------------- #
-TrustRegion = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.TrustRegion()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.TrustRegion(init=0.1)),
-    needs_closure=True,
-    func='booth', steps=50, loss=0.1, merge_invariant=True,
-    sphere_steps=10, sphere_loss=1e-5,
-)
-
 # ----------------------------------- lr/lr ---------------------------------- #
 LR = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.LR(0.1)),
@@ -587,8 +578,8 @@ UpdateGradientSignConsistency = Run(
     sphere_steps=10, sphere_loss=2,
 )
 IntermoduleCautious = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.01)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS()), tz.m.LR(0.1)),
+    func_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.01)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.IntermoduleCautious(tz.m.NAG(), tz.m.BFGS(ptol_reset=True)), tz.m.LR(0.1)),
     needs_closure=False,
     func='booth', steps=50, loss=1e-4, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.1,
@@ -601,8 +592,8 @@ ScaleByGradCosineSimilarity = Run(
     sphere_steps=10, sphere_loss=0.1,
 )
 ScaleModulesByCosineSimilarity = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.05)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS()),tz.m.LR(0.1)),
+    func_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.05)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.ScaleModulesByCosineSimilarity(tz.m.HeavyBall(0.9), tz.m.BFGS(ptol_reset=True)),tz.m.LR(0.1)),
     needs_closure=False,
     func='booth', steps=50, loss=0.005, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.1,
@@ -684,8 +675,8 @@ GradAccumulation = Run(
     sphere_steps=20, sphere_loss=1e-11,
 )
 NegateOnLossIncrease = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(),),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(),),
+    func_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.02), tz.m.NegateOnLossIncrease(True),),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.HeavyBall(), tz.m.LR(0.1), tz.m.NegateOnLossIncrease(True),),
     needs_closure=True,
     func='booth', steps=50, loss=0.1, merge_invariant=True,
     sphere_steps=20, sphere_loss=0.001,
@@ -693,7 +684,7 @@ NegateOnLossIncrease = Run(
 # -------------------------------- misc/switch ------------------------------- #
 Alternate = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(1)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.Alternate(tz.m.Adagrad(), tz.m.Adam(), tz.m.RMSprop()), tz.m.LR(0.1)),
     needs_closure=False,
     func='booth', steps=50, loss=1, merge_invariant=True,
     sphere_steps=20, sphere_loss=20,
@@ -734,24 +725,24 @@ Shampoo = Run(
 
 # ------------------------- quasi_newton/quasi_newton ------------------------ #
 BFGS = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
     needs_closure=True,
-    func='rosen', steps=50, loss=
-    sphere_steps=10, sphere_loss=
+    func='rosen', steps=50, loss=1e-10, merge_invariant=True,
+    sphere_steps=10, sphere_loss=1e-10,
 )
 SR1 = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(ptol_reset=True), tz.m.StrongWolfe()),
     needs_closure=True,
     func='rosen', steps=50, loss=1e-12, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 SSVM = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1, ptol_reset=True), tz.m.StrongWolfe()),
     needs_closure=True,
-    func='rosen', steps=50, loss=
+    func='rosen', steps=50, loss=0.5, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 
@@ -766,21 +757,21 @@ LBFGS = Run(
 
 # ----------------------------- quasi_newton/lsr1 ---------------------------- #
 LSR1 = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.LSR1(scale_second=True), tz.m.StrongWolfe()),
     needs_closure=True,
     func='rosen', steps=50, loss=0, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 
-# ---------------------------- quasi_newton/olbfgs --------------------------- #
-OnlineLBFGS = Run(
-
-
-
-
-
-)
+# # ---------------------------- quasi_newton/olbfgs --------------------------- #
+# OnlineLBFGS = Run(
+#     func_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
+#     sphere_opt=lambda p: tz.Modular(p, tz.m.OnlineLBFGS(), tz.m.StrongWolfe()),
+#     needs_closure=True,
+#     func='rosen', steps=50, loss=0, merge_invariant=True,
+#     sphere_steps=10, sphere_loss=0,
+# )
 
 # ---------------------------- second_order/newton --------------------------- #
 Newton = Run(
@@ -802,8 +793,8 @@ NewtonCG = Run(
 
 # ---------------------------- smoothing/gaussian ---------------------------- #
 GaussianHomotopy = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(), tz.m.StrongWolfe()),
+    func_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.GaussianHomotopy(10, 1, tol=1e-1, seed=0), tz.m.BFGS(ptol_reset=True), tz.m.StrongWolfe()),
     needs_closure=True,
     func='booth', steps=20, loss=0.1, merge_invariant=True,
     sphere_steps=10, sphere_loss=200,
@@ -860,7 +851,7 @@ SophiaH = Run(
     sphere_steps=10, sphere_loss=40,
 )
 
-# --------------------------
+# -------------------------- higher_order ------------------------- #
 HigherOrderNewton = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
@@ -869,6 +860,24 @@ HigherOrderNewton = Run(
     sphere_steps=1, sphere_loss=1e-10,
 )
 
+# ---------------------------- optimizers/ladagrad --------------------------- #
+LMAdagrad = Run(
+    func_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(4)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.LMAdagrad(), tz.m.LR(5)),
+    needs_closure=False,
+    func='booth', steps=50, loss=1e-6, merge_invariant=True,
+    sphere_steps=20, sphere_loss=1e-9,
+)
+
+# ------------------------------ optimizers/adan ----------------------------- #
+Adan = Run(
+    func_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(1)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.Adan(), tz.m.LR(0.1)),
+    needs_closure=False,
+    func='booth', steps=50, loss=60, merge_invariant=True,
+    sphere_steps=20, sphere_loss=60,
+)
+
 # ------------------------------------ CGs ----------------------------------- #
 for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
     for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
@@ -885,17 +894,22 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
 # stability test
 for QN in (
     tz.m.BFGS,
+    partial(tz.m.BFGS, inverse=False),
     tz.m.SR1,
+    partial(tz.m.SR1, inverse=False),
     tz.m.DFP,
+    partial(tz.m.DFP, inverse=False),
     tz.m.BroydenGood,
+    partial(tz.m.BroydenGood, inverse=False),
     tz.m.BroydenBad,
+    partial(tz.m.BroydenBad, inverse=False),
     tz.m.Greenstadt1,
     tz.m.Greenstadt2,
-    tz.m.
+    tz.m.ICUM,
     tz.m.ThomasOptimalMethod,
     tz.m.FletcherVMM,
     tz.m.Horisho,
-
+    partial(tz.m.Horisho, inner=tz.m.GradientCorrection()),
     tz.m.Pearson,
     tz.m.ProjectedNewtonRaphson,
     tz.m.PSB,
@@ -903,8 +917,8 @@ for QN in (
     tz.m.SSVM,
 ):
     Run(
-        func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
-        sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
+        func_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
+        sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False, ptol_reset=True), tz.m.StrongWolfe()),
         needs_closure=True,
        func='lstsq', steps=50, loss=1e-10, merge_invariant=False,
        sphere_steps=10, sphere_loss=1e-20,
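Outside the test harness, the new LMAdagrad and Adan entries follow the same modular construction pattern as the rest of the suite. A minimal usage sketch is shown below; the model, loss function and data loader are assumed to exist, and the hyperparameters are illustrative only:

import torchzero as tz

# sketch only: model, loss_fn and dataloader are assumed to be defined elsewhere
opt = tz.Modular(model.parameters(), tz.m.Adan(), tz.m.LR(0.1))

for inputs, targets in dataloader:
    loss = loss_fn(model(inputs), targets)
    opt.zero_grad()
    loss.backward()
    opt.step()  # no closure needed, matching needs_closure=False above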
tests/test_vars.py
CHANGED
@@ -156,6 +156,7 @@ def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
     for k,v in v1.__dict__.items():
         if not k.startswith('__'):
             # if k == 'post_step_hooks': continue
+            if k == 'storage': continue
             if k == 'update' and clone_update:
                 if v1.update is None or v2.update is None:
                     assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
torchzero/core/module.py
CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
 from collections.abc import Callable, Iterable, MutableMapping, Sequence
 from operator import itemgetter
-from typing import Any, final, overload
+from typing import Any, final, overload, Literal
 
 import torch
 
@@ -14,6 +14,7 @@ from ..utils import (
     _make_param_groups,
     get_state_vals,
 )
+from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 from ..utils.python_tools import flatten
 
 
@@ -109,6 +110,9 @@ class Var:
         self.skip_update: bool = False
         """if True, the parameters will not be updated"""
 
+        self.storage: dict = {}
+        """Storage for any other data, such as hessian estimates, etc"""
+
     def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
         """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`var.loss`.
         Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
@@ -186,6 +190,7 @@ class Var:
         if self.loss is None: self.loss = var.loss
         if self.loss_approx is None: self.loss_approx = var.loss_approx
         if self.grad is None: self.grad = var.grad
+        self.storage.update(var.storage)
 
     def zero_grad(self, set_to_none=True):
         if set_to_none:
@@ -358,6 +363,26 @@ class Module(ABC):
         # # if isinstance(params, Vars): params = params.params
         # return itemgetter(*keys)(self.settings[params[0]])
 
+    def clear_state_keys(self, *keys:str):
+        for s in self.state.values():
+            for k in keys:
+                if k in s: del s[k]
+
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: str, values: Sequence): ...
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: Sequence[str], values: Sequence[Sequence]): ...
+    def store(self, params: Sequence[torch.Tensor], keys: str | Sequence[str], values: Sequence):
+        if isinstance(keys, str):
+            for p,v in zip(params, values):
+                state = self.state[p]
+                state[keys] = v
+            return
+
+        for p, *p_v in zip(params, *values):
+            state = self.state[p]
+            for k,v in zip(keys, p_v): state[k] = v
+
     def state_dict(self):
         """state dict"""
         packed_state = {id(k):v for k,v in self.state.items()}
@@ -403,23 +428,111 @@
         self._extra_unpack(state_dict['extra'])
 
     # ---------------------------- OVERRIDABLE METHODS --------------------------- #
-    @abstractmethod
     def step(self, var: Var) -> Var:
-        """performs a step, returns new var but may update
+        """performs a step, returns new var but may update it in-place."""
+        self.update(var)
+        return self.apply(var)
+
+    def update(self, var:Var) -> Any:
+        """Updates the internal state of this module. This should not modify `var.update`.
+
+        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+        such as ::code::`tz.m.Online`.
+        """
+
+    def apply(self, var: Var) -> Var:
+        """Applies this module to ``var.get_update()``. This should not modify the internal state of this module if possible."""
+        raise NotImplementedError(f"{self} doesn't implement the `apply` method.")
 
     def reset(self):
-        """Resets the internal state of the module (e.g. momentum)."""
+        """Resets the internal state of the module (e.g. momentum). By default clears state and global state."""
         # no complex logic is allowed there because this is overridden by many modules
         # where super().reset() shouldn't be called
         self.state.clear()
         self.global_state.clear()
 
+    def reset_for_online(self):
+        """resets only the intermediate state of this module, e.g. previous parameters and gradient."""
+        for c in self.children.values(): c.reset_for_online()
+
     def _extra_pack(self):
         return {}
 
     def _extra_unpack(self, x):
         pass
 
+
+    # ------------------------------ HELPER METHODS ------------------------------ #
+    @torch.no_grad
+    def Hvp(
+        self,
+        v: Sequence[torch.Tensor],
+        at_x0: bool,
+        var: Var,
+        rgrad: Sequence[torch.Tensor] | None,
+        hvp_method: Literal['autograd', 'forward', 'central'],
+        h: float,
+        normalize: bool,
+        retain_grad: bool,
+    ):
+        """
+        Returns ``(Hvp, rgrad)``. ``rgrad`` is gradient at current parameters, possibly with create_graph=True, or it may be None with ``hvp_method="central"``. Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
+
+        Single sample example:
+
+        .. code:: py
+
+            Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+
+        Multiple samples example:
+
+        .. code:: py
+
+            D = None
+            rgrad = None
+            for i in range(n_samples):
+                v = [torch.randn_like(p) for p in params]
+                Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+        Args:
+            v (Sequence[torch.Tensor]): vector in hessian-vector product
+            at_x0 (bool): whether this is being called at original or perturbed parameters.
+            var (Var): Var
+            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
+            hvp_method (str): hvp method.
+            h (float): finite difference step size
+            normalize (bool): whether to normalize v for finite difference
+            retain_grad (bool): retain grad
+        """
+        # get grad
+        if rgrad is None and hvp_method in ('autograd', 'forward'):
+            if at_x0: rgrad = var.get_grad(create_graph = hvp_method=='autograd')
+            else:
+                if var.closure is None: raise RuntimeError("Closure is required to calculate HVp")
+                with torch.enable_grad():
+                    loss = var.closure()
+                    rgrad = torch.autograd.grad(loss, var.params, create_graph = hvp_method=='autograd')
+
+        if hvp_method == 'autograd':
+            assert rgrad is not None
+            Hvp = hvp(var.params, rgrad, v, retain_graph=retain_grad)
+
+        elif hvp_method == 'forward':
+            assert rgrad is not None
+            loss, Hvp = hvp_fd_forward(var.closure, var.params, v, h=h, g_0=rgrad, normalize=normalize)
+
+        elif hvp_method == 'central':
+            loss, Hvp = hvp_fd_central(var.closure, var.params, v, h=h, normalize=normalize)
+
+        else:
+            raise ValueError(hvp_method)
+
+        return Hvp, rgrad
+
 # endregion
 
 Chainable = Module | Sequence[Module]
@@ -440,6 +553,21 @@ def unroll_modules(*modules: Chainable) -> list[Module]:
 
 # region Modular
 # ---------------------------------- Modular --------------------------------- #
+
+class _EvalCounterClosure:
+    """keeps track of how many times closure has been evaluated"""
+    __slots__ = ("modular", "closure")
+    def __init__(self, modular: "Modular", closure):
+        self.modular = modular
+        self.closure = closure
+
+    def __call__(self, *args, **kwargs):
+        if self.closure is None:
+            raise RuntimeError("One of the modules requires closure to be passed to the step method")
+
+        self.modular.num_evaluations += 1
+        return self.closure(*args, **kwargs)
+
 # have to inherit from Modular to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
 class Modular(torch.optim.Optimizer):
@@ -496,7 +624,10 @@ class Modular(torch.optim.Optimizer):
         # self.add_param_group(param_group)
 
         self.current_step = 0
-        """
+        """global step counter for the optimizer."""
+
+        self.num_evaluations = 0
+        """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
 
     def add_param_group(self, param_group: dict[str, Any]):
         proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
@@ -558,11 +689,12 @@ class Modular(torch.optim.Optimizer):
 
         # create var
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=closure, model=self.model, current_step=self.current_step)
+        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step)
 
         # if closure is None, assume backward has been called and gather grads
         if closure is None:
             var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            self.num_evaluations += 1
 
         last_module = self.modules[-1]
        last_lr = last_module.defaults.get('lr', None)
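The update()/apply() split introduced above separates state maintenance from applying the transform, which is what lets meta-modules such as tz.m.Online re-run one phase without the other. A hypothetical module using the new hooks together with the store() helper might look like the sketch below; the class itself is illustrative, not part of the package, and the base-class constructor details are assumed:

import torch
from torchzero.core.module import Module, Var

class GradEMA(Module):
    """Hypothetical example: keeps an exponential moving average of the gradient
    and emits it as the update, so it acts like a simple momentum module."""
    BETA = 0.9  # hard-coded to avoid assuming the Module __init__ signature

    def update(self, var: Var) -> None:
        # state maintenance only; var.update is not touched here
        grad = var.get_grad()
        emas = []
        for p, g in zip(var.params, grad):
            ema = self.state[p].get("ema")
            if ema is None:
                ema = torch.zeros_like(g)
            emas.append(ema.mul_(self.BETA).add_(g, alpha=1 - self.BETA))
        self.store(var.params, "ema", emas)  # per-parameter state via the new helper

    def apply(self, var: Var) -> Var:
        # produce the update from the stored state without modifying that state
        var.update = [self.state[p]["ema"].clone() for p in var.params]
        return var

Because step() now defaults to update() followed by apply(), a module written this way also works when driven as a plain step.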