torchzero 0.3.10__tar.gz → 0.3.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.3.10 → torchzero-0.3.11}/PKG-INFO +65 -40
- {torchzero-0.3.10 → torchzero-0.3.11}/README.md +64 -39
- {torchzero-0.3.10 → torchzero-0.3.11}/docs/source/conf.py +6 -4
- torchzero-0.3.11/docs/source/docstring template.py +46 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/pyproject.toml +2 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/tests/test_identical.py +2 -3
- {torchzero-0.3.10 → torchzero-0.3.11}/tests/test_opts.py +64 -50
- {torchzero-0.3.10 → torchzero-0.3.11}/tests/test_vars.py +1 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/core/module.py +138 -6
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/core/transform.py +158 -51
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/__init__.py +3 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/clipping/clipping.py +114 -17
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/clipping/ema_clipping.py +27 -13
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero-0.3.11/torchzero/modules/experimental/__init__.py +41 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/absoap.py +5 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/adadam.py +8 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/adamY.py +8 -2
- torchzero-0.3.11/torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero-0.3.10/torchzero/modules/line_search/trust_region.py → torchzero-0.3.11/torchzero/modules/experimental/adaptive_step_size.py +21 -4
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/adasoap.py +7 -2
- torchzero-0.3.11/torchzero/modules/experimental/cosine.py +214 -0
- torchzero-0.3.11/torchzero/modules/experimental/cubic_adam.py +97 -0
- {torchzero-0.3.10/torchzero/modules/projections → torchzero-0.3.11/torchzero/modules/experimental}/dct.py +11 -11
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/eigendescent.py +4 -1
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/etf.py +32 -9
- torchzero-0.3.11/torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero-0.3.11/torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- {torchzero-0.3.10/torchzero/modules/projections → torchzero-0.3.11/torchzero/modules/experimental}/fft.py +10 -10
- torchzero-0.3.11/torchzero/modules/experimental/hnewton.py +85 -0
- {torchzero-0.3.10/torchzero/modules/quasi_newton → torchzero-0.3.11/torchzero/modules}/experimental/modular_lbfgs.py +27 -28
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero-0.3.11/torchzero/modules/experimental/parabolic_search.py +220 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero-0.3.10/torchzero/modules/projections/structural.py → torchzero-0.3.11/torchzero/modules/experimental/structural_projections.py +12 -54
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero-0.3.10/torchzero/modules/experimental/tada.py → torchzero-0.3.11/torchzero/modules/experimental/tensor_adagrad.py +10 -6
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/functional.py +12 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/grad_approximation/fdm.py +30 -3
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero-0.3.11/torchzero/modules/grad_approximation/rfdm.py +519 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero-0.3.11/torchzero/modules/line_search/__init__.py +5 -0
- torchzero-0.3.11/torchzero/modules/line_search/adaptive.py +99 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/line_search/backtracking.py +34 -9
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/line_search/line_search.py +70 -12
- torchzero-0.3.11/torchzero/modules/line_search/polynomial.py +233 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/line_search/scipy.py +2 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero-0.3.11/torchzero/modules/misc/__init__.py +27 -0
- {torchzero-0.3.10/torchzero/modules/ops → torchzero-0.3.11/torchzero/modules/misc}/debug.py +24 -1
- torchzero-0.3.11/torchzero/modules/misc/escape.py +60 -0
- torchzero-0.3.11/torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero-0.3.11/torchzero/modules/misc/misc.py +316 -0
- torchzero-0.3.11/torchzero/modules/misc/multistep.py +158 -0
- torchzero-0.3.11/torchzero/modules/misc/regularization.py +171 -0
- {torchzero-0.3.10/torchzero/modules/ops → torchzero-0.3.11/torchzero/modules/misc}/split.py +29 -1
- {torchzero-0.3.10/torchzero/modules/ops → torchzero-0.3.11/torchzero/modules/misc}/switch.py +44 -3
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/momentum/__init__.py +1 -1
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/momentum/averaging.py +6 -6
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/momentum/cautious.py +45 -8
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/momentum/ema.py +7 -7
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/momentum/experimental.py +2 -2
- torchzero-0.3.11/torchzero/modules/momentum/matrix_momentum.py +193 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/momentum/momentum.py +2 -1
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/ops/__init__.py +3 -31
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/ops/accumulate.py +6 -10
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/ops/binary.py +72 -26
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/ops/multi.py +77 -16
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/ops/reduce.py +15 -7
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/ops/unary.py +29 -13
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/ops/utility.py +20 -12
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/__init__.py +12 -3
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero-0.3.11/torchzero/modules/optimizers/adahessian.py +223 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/adam.py +7 -6
- torchzero-0.3.11/torchzero/modules/optimizers/adan.py +110 -0
- torchzero-0.3.11/torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero-0.3.11/torchzero/modules/optimizers/esgd.py +171 -0
- torchzero-0.3.10/torchzero/modules/experimental/spectral.py → torchzero-0.3.11/torchzero/modules/optimizers/ladagrad.py +91 -71
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/lion.py +1 -1
- torchzero-0.3.11/torchzero/modules/optimizers/mars.py +91 -0
- torchzero-0.3.11/torchzero/modules/optimizers/msam.py +186 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/muon.py +30 -5
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/orthograd.py +1 -1
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/rmsprop.py +7 -4
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/rprop.py +42 -8
- torchzero-0.3.11/torchzero/modules/optimizers/sam.py +163 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/shampoo.py +39 -5
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/soap.py +29 -19
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero-0.3.11/torchzero/modules/projections/__init__.py +3 -0
- torchzero-0.3.11/torchzero/modules/projections/cast.py +51 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/projections/galore.py +3 -1
- torchzero-0.3.11/torchzero/modules/projections/projection.py +338 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/quasi_newton/__init__.py +12 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero-0.3.11/torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero-0.3.11/torchzero/modules/quasi_newton/lbfgs.py +286 -0
- torchzero-0.3.11/torchzero/modules/quasi_newton/lsr1.py +218 -0
- torchzero-0.3.11/torchzero/modules/quasi_newton/quasi_newton.py +1331 -0
- torchzero-0.3.11/torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero-0.3.11/torchzero/modules/second_order/__init__.py +3 -0
- torchzero-0.3.11/torchzero/modules/second_order/newton.py +338 -0
- torchzero-0.3.11/torchzero/modules/second_order/newton_cg.py +374 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/second_order/nystrom.py +104 -1
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/smoothing/gaussian.py +34 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero-0.3.11/torchzero/modules/step_size/__init__.py +2 -0
- torchzero-0.3.11/torchzero/modules/step_size/adaptive.py +122 -0
- torchzero-0.3.11/torchzero/modules/step_size/lr.py +154 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero-0.3.11/torchzero/modules/weight_decay/weight_decay.py +168 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/wrappers/optim_wrapper.py +29 -1
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/wrappers/directsearch.py +39 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/wrappers/fcmaes.py +21 -13
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/wrappers/mads.py +5 -6
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/wrappers/nevergrad.py +16 -1
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/wrappers/optuna.py +1 -1
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/wrappers/scipy.py +5 -3
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/__init__.py +2 -2
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/derivatives.py +3 -3
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/linalg/__init__.py +1 -1
- torchzero-0.3.11/torchzero/utils/linalg/solve.py +408 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/numberlist.py +2 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/python_tools.py +10 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero.egg-info/PKG-INFO +65 -40
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero.egg-info/SOURCES.txt +39 -19
- torchzero-0.3.10/torchzero/modules/experimental/__init__.py +0 -24
- torchzero-0.3.10/torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero-0.3.10/torchzero/modules/experimental/soapy.py +0 -163
- torchzero-0.3.10/torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero-0.3.10/torchzero/modules/grad_approximation/rfdm.py +0 -272
- torchzero-0.3.10/torchzero/modules/line_search/__init__.py +0 -5
- torchzero-0.3.10/torchzero/modules/lr/__init__.py +0 -2
- torchzero-0.3.10/torchzero/modules/lr/adaptive.py +0 -93
- torchzero-0.3.10/torchzero/modules/lr/lr.py +0 -63
- torchzero-0.3.10/torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero-0.3.10/torchzero/modules/ops/misc.py +0 -418
- torchzero-0.3.10/torchzero/modules/projections/__init__.py +0 -5
- torchzero-0.3.10/torchzero/modules/projections/projection.py +0 -244
- torchzero-0.3.10/torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero-0.3.10/torchzero/modules/quasi_newton/lbfgs.py +0 -229
- torchzero-0.3.10/torchzero/modules/quasi_newton/lsr1.py +0 -174
- torchzero-0.3.10/torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10/torchzero/modules/quasi_newton/quasi_newton.py +0 -683
- torchzero-0.3.10/torchzero/modules/second_order/__init__.py +0 -3
- torchzero-0.3.10/torchzero/modules/second_order/newton.py +0 -159
- torchzero-0.3.10/torchzero/modules/second_order/newton_cg.py +0 -85
- torchzero-0.3.10/torchzero/modules/weight_decay/weight_decay.py +0 -86
- torchzero-0.3.10/torchzero/utils/linalg/solve.py +0 -169
- {torchzero-0.3.10 → torchzero-0.3.11}/LICENSE +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/setup.cfg +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/tests/test_module.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/tests/test_tensorlist.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/core/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/curveball.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/gradmin.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/experimental/newton_solver.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/grad_approximation/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/higher_order/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/compile.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/linalg/benchmark.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/linalg/matrix_funcs.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/linalg/orthogonalize.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/linalg/qr.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/linalg/svd.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/ops.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/optimizer.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/params.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.3.10 → torchzero-0.3.11}/torchzero.egg-info/top_level.txt +0 -0
--- torchzero-0.3.10/PKG-INFO
+++ torchzero-0.3.11/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.10
+Version: 0.3.11
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 License: MIT License
@@ -45,8 +45,6 @@ Dynamic: license-file
 
 `torchzero` is a PyTorch library providing a highly modular framework for creating and experimenting with a huge number of various optimization algorithms - various momentum techniques, gradient clipping, gradient approximations, line searches, quasi newton methods and more. All algorithms are implemented as modules that can be chained together freely.
 
-NOTE: torchzero is in active development, currently docs are in a state of flux.
-
 ## Installation
 
 ```bash
@@ -113,31 +111,21 @@ for epoch in range(100):
 `torchzero` provides a huge number of various modules:
 
 * **Optimizers**: Optimization algorithms.
-  * `Adam`.
-  * `Shampoo`.
-  * `SOAP` (my current recommendation).
-  * `Muon`.
-  * `SophiaH`.
-  * `Adagrad` and `FullMatrixAdagrad`.
-  * `Lion`.
-  * `RMSprop`.
-  * `OrthoGrad`.
-  * `Rprop`.
+  * `Adam`, `Adan`, `Adagrad`, `ESGD`, `FullMatrixAdagrad`, `LMAdagrad`, `AdaHessian`, `AdaptiveHeavyBall`, `OrthoGrad`, `Lion`, `MARS`, `MatrixMomentum`, `AdaptiveMatrixMomentum`, `Muon`, `RMSprop`, `Rprop`, `SAM`, `ASAM`, `MSAM`, `Shampoo`, `SOAP`, `SophiaH`.
 
 Additionally many other optimizers can be easily defined via modules:
 * Grams: `[tz.m.Adam(), tz.m.GradSign()]`
 * LaProp: `[tz.m.RMSprop(), tz.m.EMA(0.9)]`
 * Signum: `[tz.m.HeavyBall(), tz.m.Sign()]`
-* 
+* Efficient full-matrix version of any diagonal optimizer, like Adam: `[tz.m.LMAdagrad(beta=0.999, inner=tz.m.EMA(0.9)), tz.m.Debias(0.9, 0.999)]`
 * Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`
 
 * **Momentum**:
-  * `NAG`: Nesterov Accelerated Gradient.
   * `HeavyBall`: Classic momentum (Polyak's momentum).
+  * `NAG`: Nesterov Accelerated Gradient.
   * `EMA`: Exponential moving average.
-  * `Averaging` (`
+  * `Averaging` (`MedianAveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
   * `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.
-  * `MatrixMomentum`, `AdaptiveMatrixMomentum`: Second order momentum.
 
 * **Stabilization**: Gradient stabilization techniques.
   * `ClipNorm`: Clips gradient L2 norm.
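The module-composition recipes in the hunk above are just argument lists for `tz.Modular`; below is a minimal sketch of how the Grams recipe from that list would be used in a training loop, assuming `tz.Modular` exposes the usual `torch.optim.Optimizer` interface and that the listed module constructors work with their defaults (the learning-rate value is illustrative, not taken from the diff).

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# Grams-style optimizer from the list above: Adam statistics, gradient signs,
# followed by the LR module from the "Learning Rate" section to scale the step.
opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.GradSign(), tz.m.LR(1e-2))

for _ in range(100):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(X), y)
    loss.backward()
    opt.step()
```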
@@ -154,32 +142,42 @@ for epoch in range(100):
 
 * **Second order**: Second order methods.
   * `Newton`: Classic Newton's method.
-  * `
+  * `InverseFreeNewton`: Inverse-free version of Newton's method.
+  * `NewtonCG`: Matrix-free newton's method with conjugate gradient or minimal residual solvers.
+  * `TruncatedNewtonCG`: Steihaug-Toint Trust-region NewtonCG via a truncated CG solver.
   * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-  * `NystromPCG`: NewtonCG with Nyström preconditioning
+  * `NystromPCG`: NewtonCG with Nyström preconditioning.
   * `HigherOrderNewton`: Higher order Newton's method with trust region.
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
   * `LBFGS`: Limited-memory BFGS.
   * `LSR1`: Limited-memory SR1.
   * `OnlineLBFGS`: Online LBFGS.
-  * `BFGS`, `DFP`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `
+  * `BFGS`, `DFP`, `ICUM`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `NewSSM`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`, `ShorR`: Full-matrix quasi-newton methods.
+  * `DiagonalBFGS`, `DiagonalSR1`, `DiagonalQuasiCauchi`, `DiagonalWeightedQuasiCauchi`, `DNRTR`, `NewDQN`: Diagonal quasi-newton methods.
   * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
 
+* **Trust Region** Trust region can work with exact hessian or any of the quasi-newton methods (L-BFGS support is WIP)
+  * `TrustCG`: Trust-region, uses a Steihaug-Toint truncated CG solver.
+  * `CubicRegularization`: Cubic regularization, works better with exact hessian.
+
 * **Line Search**:
   * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
   * `StrongWolfe`: Cubic interpolation line search satisfying strong Wolfe conditions.
   * `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.
-  * `TrustRegion`: First order trust region method.
 
 * **Learning Rate**:
   * `LR`: Controls learning rate and adds support for LR schedulers.
-  * `PolyakStepSize`: Polyak's method.
-  * `
+  * `PolyakStepSize`: Polyak's subgradient method.
+  * `BarzilaiBorwein`: Barzilai-Borwein step-size.
+  * `Warmup`, `WarmupNormCLip`: Learning rate warmup.
 
 * **Projections**: This can implement things like GaLore but I haven't done that yet.
-  * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
-  * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.).
+  <!-- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
+  * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.). -->
+  This is WIP
+  * `To`: this casts everything to any other dtype and device for other modules, e.g. if you want better precision
+  * `ViewAsReal`: put if you have complex paramters.
 
 * **Smoothing**: Smoothing-based optimization methods.
   * `LaplacianSmoothing`: Laplacian smoothing for gradients (implements Laplacian Smooth GD).
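The quasi-newton and line-search modules listed in the hunk above chain the same way as any other modules, but a line search has to re-evaluate the objective, so the step takes a closure. Below is a minimal sketch combining `LBFGS` with `Backtracking`, assuming default constructors and the closure-with-`backward`-flag convention that the README describes under Advanced Usage; the toy objective and loop length are illustrative.

```python
import torch
import torchzero as tz

params = [torch.nn.Parameter(torch.randn(10))]
opt = tz.Modular(params, tz.m.LBFGS(), tz.m.Backtracking())

def closure(backward=True):
    # simple quadratic objective; the line search calls this repeatedly
    loss = (params[0] ** 2).sum()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(5):
    opt.step(closure)
```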
@@ -195,6 +193,8 @@ for epoch in range(100):
 
 * **Experimental**: various horrible atrocities
 
+A complete list of modules is available in the [documentation](https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html).
+
 ## Advanced Usage
 
 ### Closure
@@ -321,6 +321,7 @@ class HeavyBall(Module):
         super().__init__(defaults)
 
     def step(self, var: Var):
+        # Var object holds all attributes used for optimization - parameters, gradient, update, etc.
        # a module takes a Var object, modifies it or creates a new one, and returns it
        # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
        # for now we are only interested in update, and we will apply the heavyball rule to it.
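For readers following the custom-module example that this hunk touches, here is a rough sketch of what a complete `step` in this style can look like. The `Module`/`Var` names, `var.update`, and `super().__init__(defaults)` come from the README text in the diff; the import path, the plain-attribute state storage, and treating `var.update` as an iterable of tensors are assumptions made for the sketch, not the package's actual `HeavyBall`.

```python
import torch
from torchzero.core import Module, Var  # import path assumed from the package layout

class SketchHeavyBall(Module):
    """Illustrative heavy-ball module; not the library's implementation."""

    def __init__(self, momentum: float = 0.9):
        defaults = dict(momentum=momentum)
        super().__init__(defaults)
        self._velocity = None  # previous update, kept as a plain attribute for simplicity

    def step(self, var: Var):
        momentum = 0.9  # per-parameter settings lookup omitted in this sketch
        if self._velocity is None:
            self._velocity = [u.clone() for u in var.update]
        else:
            for v, u in zip(self._velocity, var.update):
                v.mul_(momentum).add_(u)  # v = momentum * v + u
                u.copy_(v)                # write the accelerated update back
        return var
```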
@@ -352,28 +353,52 @@ class HeavyBall(Module):
         return var
 ```
 
-
+More in-depth guide will be available in the documentation in the future.
+
+## Other stuff
 
-
-* `LineSearch` for line searches
-* `Projection` for projections like GaLore or into fourier domain.
-* `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
-* `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
+There are also wrappers providing `torch.optim.Optimizer` interface for various other libraries. When using those, make sure closure has `backward` argument as described in **Advanced Usage**.
 
-
+---
 
-
+### Scipy
 
-
+#### torchzero.optim.wrappers.scipy.ScipyMinimize
 
-
+A wrapper for `scipy.optimize.minimize` with gradients and hessians supplied by pytorch autograd. Scipy provides implementations of the following methods: `'nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'`.
 
-
+#### torchzero.optim.wrappers.scipy.ScipyDE, ScipyDualAnnealing, ScipySHGO, ScipyDIRECT, ScipyBrute
 
-
+Equivalent wrappers for other derivative free solvers available in `scipy.optimize`
+
+---
+
+### NLOpt
+
+#### torchzero.optim.wrappers.nlopt.NLOptWrapper
 
-
+A wrapper for [NLOpt](https://github.com/stevengj/nlopt) with gradients supplied by pytorch autograd. NLOpt is another popular library with many gradient based and gradient free [algorithms](https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/)
+
+---
+
+### Nevergrad
+
+#### torchzero.optim.wrappers.nevergrad.NevergradWrapper
+
+A wrapper for [nevergrad](https://facebookresearch.github.io/nevergrad/) which has a huge library of gradient free [algorithms](https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers)
+
+---
+
+### fast-cma-es
+
+#### torchzero.optim.wrappers.fcmaes.FcmaesWrapper
+
+A wrapper for [fast-cma-es](https://github.com/dietmarwo/fast-cma-es), which implements various gradient free algorithms. Notably it includes [BITEOPT](https://github.com/avaneev/biteopt) which seems to have very good performance in benchmarks.
+
+# License
+
+This project is licensed under the MIT License
 
-
+# Project Links
 
-
+The documentation is available at <https://torchzero.readthedocs.io/en/latest/>
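The closure requirement mentioned in the hunk above is the only torchzero-specific part of using these wrappers. A minimal sketch with `ScipyMinimize` follows; the class path comes from the text above, while the bare constructor (default scipy method) and the exact closure semantics are assumptions for illustration only.

```python
import torch
from torchzero.optim.wrappers.scipy import ScipyMinimize

model = torch.nn.Linear(10, 1)
X, y = torch.randn(32, 10), torch.randn(32, 1)

opt = ScipyMinimize(model.parameters())  # default scipy method assumed

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:  # derivative-free methods may call the closure with backward=False
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```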
--- torchzero-0.3.10/README.md
+++ torchzero-0.3.11/README.md
@@ -6,8 +6,6 @@
 
 `torchzero` is a PyTorch library providing a highly modular framework for creating and experimenting with a huge number of various optimization algorithms - various momentum techniques, gradient clipping, gradient approximations, line searches, quasi newton methods and more. All algorithms are implemented as modules that can be chained together freely.
 
-NOTE: torchzero is in active development, currently docs are in a state of flux.
-
 ## Installation
 
 ```bash
@@ -74,31 +72,21 @@ for epoch in range(100):
 `torchzero` provides a huge number of various modules:
 
 * **Optimizers**: Optimization algorithms.
-  * `Adam`.
-  * `Shampoo`.
-  * `SOAP` (my current recommendation).
-  * `Muon`.
-  * `SophiaH`.
-  * `Adagrad` and `FullMatrixAdagrad`.
-  * `Lion`.
-  * `RMSprop`.
-  * `OrthoGrad`.
-  * `Rprop`.
+  * `Adam`, `Adan`, `Adagrad`, `ESGD`, `FullMatrixAdagrad`, `LMAdagrad`, `AdaHessian`, `AdaptiveHeavyBall`, `OrthoGrad`, `Lion`, `MARS`, `MatrixMomentum`, `AdaptiveMatrixMomentum`, `Muon`, `RMSprop`, `Rprop`, `SAM`, `ASAM`, `MSAM`, `Shampoo`, `SOAP`, `SophiaH`.
 
 Additionally many other optimizers can be easily defined via modules:
 * Grams: `[tz.m.Adam(), tz.m.GradSign()]`
 * LaProp: `[tz.m.RMSprop(), tz.m.EMA(0.9)]`
 * Signum: `[tz.m.HeavyBall(), tz.m.Sign()]`
-* 
+* Efficient full-matrix version of any diagonal optimizer, like Adam: `[tz.m.LMAdagrad(beta=0.999, inner=tz.m.EMA(0.9)), tz.m.Debias(0.9, 0.999)]`
 * Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`
 
 * **Momentum**:
-  * `NAG`: Nesterov Accelerated Gradient.
   * `HeavyBall`: Classic momentum (Polyak's momentum).
+  * `NAG`: Nesterov Accelerated Gradient.
   * `EMA`: Exponential moving average.
-  * `Averaging` (`
+  * `Averaging` (`MedianAveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
   * `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.
-  * `MatrixMomentum`, `AdaptiveMatrixMomentum`: Second order momentum.
 
 * **Stabilization**: Gradient stabilization techniques.
   * `ClipNorm`: Clips gradient L2 norm.
@@ -115,32 +103,42 @@ for epoch in range(100):
 
 * **Second order**: Second order methods.
   * `Newton`: Classic Newton's method.
-  * `
+  * `InverseFreeNewton`: Inverse-free version of Newton's method.
+  * `NewtonCG`: Matrix-free newton's method with conjugate gradient or minimal residual solvers.
+  * `TruncatedNewtonCG`: Steihaug-Toint Trust-region NewtonCG via a truncated CG solver.
   * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-  * `NystromPCG`: NewtonCG with Nyström preconditioning
+  * `NystromPCG`: NewtonCG with Nyström preconditioning.
   * `HigherOrderNewton`: Higher order Newton's method with trust region.
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
   * `LBFGS`: Limited-memory BFGS.
   * `LSR1`: Limited-memory SR1.
   * `OnlineLBFGS`: Online LBFGS.
-  * `BFGS`, `DFP`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `
+  * `BFGS`, `DFP`, `ICUM`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `NewSSM`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`, `ShorR`: Full-matrix quasi-newton methods.
+  * `DiagonalBFGS`, `DiagonalSR1`, `DiagonalQuasiCauchi`, `DiagonalWeightedQuasiCauchi`, `DNRTR`, `NewDQN`: Diagonal quasi-newton methods.
   * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
 
+* **Trust Region** Trust region can work with exact hessian or any of the quasi-newton methods (L-BFGS support is WIP)
+  * `TrustCG`: Trust-region, uses a Steihaug-Toint truncated CG solver.
+  * `CubicRegularization`: Cubic regularization, works better with exact hessian.
+
 * **Line Search**:
   * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
   * `StrongWolfe`: Cubic interpolation line search satisfying strong Wolfe conditions.
   * `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.
-  * `TrustRegion`: First order trust region method.
 
 * **Learning Rate**:
   * `LR`: Controls learning rate and adds support for LR schedulers.
-  * `PolyakStepSize`: Polyak's method.
-  * `
+  * `PolyakStepSize`: Polyak's subgradient method.
+  * `BarzilaiBorwein`: Barzilai-Borwein step-size.
+  * `Warmup`, `WarmupNormCLip`: Learning rate warmup.
 
 * **Projections**: This can implement things like GaLore but I haven't done that yet.
-  * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
-  * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.).
+  <!-- * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
+  * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.). -->
+  This is WIP
+  * `To`: this casts everything to any other dtype and device for other modules, e.g. if you want better precision
+  * `ViewAsReal`: put if you have complex paramters.
 
 * **Smoothing**: Smoothing-based optimization methods.
   * `LaplacianSmoothing`: Laplacian smoothing for gradients (implements Laplacian Smooth GD).
@@ -156,6 +154,8 @@ for epoch in range(100):
 
 * **Experimental**: various horrible atrocities
 
+A complete list of modules is available in the [documentation](https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html).
+
 ## Advanced Usage
 
 ### Closure
@@ -282,6 +282,7 @@ class HeavyBall(Module):
         super().__init__(defaults)
 
     def step(self, var: Var):
+        # Var object holds all attributes used for optimization - parameters, gradient, update, etc.
        # a module takes a Var object, modifies it or creates a new one, and returns it
        # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
        # for now we are only interested in update, and we will apply the heavyball rule to it.
@@ -313,28 +314,52 @@ class HeavyBall(Module):
         return var
 ```
 
-
+More in-depth guide will be available in the documentation in the future.
+
+## Other stuff
 
-
-* `LineSearch` for line searches
-* `Projection` for projections like GaLore or into fourier domain.
-* `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
-* `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
+There are also wrappers providing `torch.optim.Optimizer` interface for various other libraries. When using those, make sure closure has `backward` argument as described in **Advanced Usage**.
 
-
+---
 
-
+### Scipy
 
-
+#### torchzero.optim.wrappers.scipy.ScipyMinimize
 
-
+A wrapper for `scipy.optimize.minimize` with gradients and hessians supplied by pytorch autograd. Scipy provides implementations of the following methods: `'nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'`.
 
-
+#### torchzero.optim.wrappers.scipy.ScipyDE, ScipyDualAnnealing, ScipySHGO, ScipyDIRECT, ScipyBrute
 
-
+Equivalent wrappers for other derivative free solvers available in `scipy.optimize`
+
+---
+
+### NLOpt
+
+#### torchzero.optim.wrappers.nlopt.NLOptWrapper
 
-
+A wrapper for [NLOpt](https://github.com/stevengj/nlopt) with gradients supplied by pytorch autograd. NLOpt is another popular library with many gradient based and gradient free [algorithms](https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/)
+
+---
+
+### Nevergrad
+
+#### torchzero.optim.wrappers.nevergrad.NevergradWrapper
+
+A wrapper for [nevergrad](https://facebookresearch.github.io/nevergrad/) which has a huge library of gradient free [algorithms](https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers)
+
+---
+
+### fast-cma-es
+
+#### torchzero.optim.wrappers.fcmaes.FcmaesWrapper
+
+A wrapper for [fast-cma-es](https://github.com/dietmarwo/fast-cma-es), which implements various gradient free algorithms. Notably it includes [BITEOPT](https://github.com/avaneev/biteopt) which seems to have very good performance in benchmarks.
+
+# License
+
+This project is licensed under the MIT License
 
-
+# Project Links
 
-
+The documentation is available at <https://torchzero.readthedocs.io/en/latest/>
--- torchzero-0.3.10/docs/source/conf.py
+++ torchzero-0.3.11/docs/source/conf.py
@@ -6,10 +6,10 @@
 # -- Project information -----------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
 import sys, os
-#sys.path.insert(0, os.path.abspath('.../src'))
+#sys.path.insert(0, os.path.abspath('.../src'))
 
 project = 'torchzero'
-copyright = '
+copyright = '2025, Ivan Nikishev'
 author = 'Ivan Nikishev'
 
 # -- General configuration ---------------------------------------------------
@@ -24,10 +24,12 @@ extensions = [
     'sphinx.ext.githubpages',
     'sphinx.ext.napoleon',
     'autoapi.extension',
+    "myst_nb",
+
     # 'sphinx_rtd_theme',
 ]
 autosummary_generate = True
-autoapi_dirs = ['../../
+autoapi_dirs = ['../../torchzero']
 autoapi_type = "python"
 # autoapi_ignore = ["*/tensorlist.py"]
 
@@ -48,7 +50,7 @@ exclude_patterns = []
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
 
 #html_theme = 'alabaster'
-html_theme = '
+html_theme = 'sphinx_rtd_theme'
 html_static_path = ['_static']
 
 
--- /dev/null
+++ torchzero-0.3.11/docs/source/docstring template.py
@@ -0,0 +1,46 @@
+class MyModule:
+    """[One-line summary of the class].
+
+    [A more detailed description of the class, explaining its purpose, how it
+    works, and its typical use cases. You can use multiple paragraphs.]
+
+    .. note::
+        [Optional: Add important notes, warnings, or usage guidelines here.
+        For example, you could mention if a closure is required, discuss
+        stability, or highlight performance characteristics. Use the `.. note::`
+        directive to make it stand out in the documentation.]
+
+    Args:
+        param1 (type, optional):
+            [Description of the first parameter. Use :code:`backticks` for
+            inline code like variable names or specific values like ``"autograd"``.
+            Explain what the parameter does.] Defaults to [value].
+        param2 (type):
+            [Description of a mandatory parameter (no "optional" or "Defaults to").]
+        **kwargs:
+            [If you accept keyword arguments, describe what they are used for.]
+
+    Examples:
+        [A title or short sentence describing the first example]:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                ...
+            )
+
+        [A title or short sentence for a second, different example]:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                ...
+            )
+
+    References:
+        - [Optional: A citation for a relevant paper, book, or algorithm.]
+        - [Optional: A link to a blog post or website with more information.]
+
+    """
--- torchzero-0.3.10/pyproject.toml
+++ torchzero-0.3.11/pyproject.toml
@@ -1,5 +1,5 @@
 # NEW VERSION TUTORIAL FOR MYSELF
-# STEP 1 - COMMIT NEW CHANGES
+# STEP 1 - COMMIT NEW CHANGES AND PUSH THEM
 # STEP 2 - BUMP VERSION AND COMMIT IT (DONT PUSH!!!!)
 # STEP 3 - CREATE TAG WITH THAT VERSION
 # STEP 4 - PUSH (SYNC) CHANGES
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
 name = "torchzero"
 description = "Modular optimization library for PyTorch."
 
-version = "0.3.10"
+version = "0.3.11"
 dependencies = [
     "torch",
     "numpy",
--- torchzero-0.3.10/tests/test_identical.py
+++ torchzero-0.3.11/tests/test_identical.py
@@ -96,8 +96,7 @@ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, s
 
 @pytest.mark.parametrize('amsgrad', [True, False])
 def test_adam(amsgrad):
-
-    # pytorch applies debiasing separately so it is applied before epsilo
+    torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
     tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
     tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
     tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
@@ -133,7 +132,7 @@ def test_adam(amsgrad):
         tz.m.Debias2(beta=0.999),
         tz.m.Add(1e-8)]
     ))
-    tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
+    tz_fns = (torch_fn, tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
 
     _assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
     for fn in tz_fns: