torchzero 0.1.7.tar.gz → 0.3.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.1.7 → torchzero-0.3.1}/LICENSE +0 -0
- torchzero-0.3.1/PKG-INFO +379 -0
- torchzero-0.3.1/README.md +340 -0
- torchzero-0.3.1/docs/source/conf.py +57 -0
- {torchzero-0.1.7 → torchzero-0.3.1}/pyproject.toml +14 -4
- torchzero-0.3.1/tests/test_identical.py +230 -0
- torchzero-0.3.1/tests/test_module.py +50 -0
- torchzero-0.3.1/tests/test_opts.py +884 -0
- torchzero-0.3.1/tests/test_tensorlist.py +1787 -0
- torchzero-0.3.1/tests/test_utils_optimizer.py +170 -0
- torchzero-0.3.1/tests/test_vars.py +184 -0
- torchzero-0.3.1/torchzero/__init__.py +4 -0
- torchzero-0.3.1/torchzero/core/__init__.py +3 -0
- torchzero-0.3.1/torchzero/core/module.py +629 -0
- torchzero-0.3.1/torchzero/core/preconditioner.py +137 -0
- torchzero-0.3.1/torchzero/core/transform.py +252 -0
- torchzero-0.3.1/torchzero/modules/__init__.py +13 -0
- torchzero-0.3.1/torchzero/modules/clipping/__init__.py +3 -0
- torchzero-0.3.1/torchzero/modules/clipping/clipping.py +320 -0
- torchzero-0.3.1/torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero-0.3.1/torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero-0.3.1/torchzero/modules/experimental/__init__.py +14 -0
- torchzero-0.3.1/torchzero/modules/experimental/absoap.py +350 -0
- torchzero-0.3.1/torchzero/modules/experimental/adadam.py +111 -0
- torchzero-0.3.1/torchzero/modules/experimental/adamY.py +135 -0
- torchzero-0.3.1/torchzero/modules/experimental/adasoap.py +282 -0
- torchzero-0.3.1/torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero-0.3.1/torchzero/modules/experimental/curveball.py +89 -0
- torchzero-0.3.1/torchzero/modules/experimental/dsoap.py +290 -0
- torchzero-0.3.1/torchzero/modules/experimental/gradmin.py +85 -0
- torchzero-0.3.1/torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero-0.3.1/torchzero/modules/experimental/spectral.py +286 -0
- torchzero-0.3.1/torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero-0.3.1/torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero-0.3.1/torchzero/modules/functional.py +209 -0
- torchzero-0.3.1/torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero-0.3.1/torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero-0.3.1/torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero-0.3.1/torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero-0.3.1/torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero-0.3.1/torchzero/modules/line_search/__init__.py +5 -0
- torchzero-0.3.1/torchzero/modules/line_search/backtracking.py +186 -0
- torchzero-0.3.1/torchzero/modules/line_search/line_search.py +181 -0
- torchzero-0.3.1/torchzero/modules/line_search/scipy.py +37 -0
- torchzero-0.3.1/torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero-0.3.1/torchzero/modules/line_search/trust_region.py +61 -0
- torchzero-0.3.1/torchzero/modules/lr/__init__.py +2 -0
- torchzero-0.3.1/torchzero/modules/lr/lr.py +59 -0
- torchzero-0.3.1/torchzero/modules/lr/step_size.py +97 -0
- torchzero-0.3.1/torchzero/modules/momentum/__init__.py +14 -0
- torchzero-0.3.1/torchzero/modules/momentum/averaging.py +78 -0
- torchzero-0.3.1/torchzero/modules/momentum/cautious.py +181 -0
- torchzero-0.3.1/torchzero/modules/momentum/ema.py +173 -0
- torchzero-0.3.1/torchzero/modules/momentum/experimental.py +189 -0
- torchzero-0.3.1/torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero-0.3.1/torchzero/modules/momentum/momentum.py +43 -0
- torchzero-0.3.1/torchzero/modules/ops/__init__.py +103 -0
- torchzero-0.3.1/torchzero/modules/ops/accumulate.py +65 -0
- torchzero-0.3.1/torchzero/modules/ops/binary.py +240 -0
- torchzero-0.3.1/torchzero/modules/ops/debug.py +25 -0
- torchzero-0.3.1/torchzero/modules/ops/misc.py +419 -0
- torchzero-0.3.1/torchzero/modules/ops/multi.py +137 -0
- torchzero-0.3.1/torchzero/modules/ops/reduce.py +149 -0
- torchzero-0.3.1/torchzero/modules/ops/split.py +75 -0
- torchzero-0.3.1/torchzero/modules/ops/switch.py +68 -0
- torchzero-0.3.1/torchzero/modules/ops/unary.py +115 -0
- torchzero-0.3.1/torchzero/modules/ops/utility.py +112 -0
- torchzero-0.3.1/torchzero/modules/optimizers/__init__.py +18 -0
- torchzero-0.3.1/torchzero/modules/optimizers/adagrad.py +146 -0
- torchzero-0.3.1/torchzero/modules/optimizers/adam.py +112 -0
- torchzero-0.3.1/torchzero/modules/optimizers/lion.py +35 -0
- torchzero-0.3.1/torchzero/modules/optimizers/muon.py +222 -0
- torchzero-0.3.1/torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero-0.3.1/torchzero/modules/optimizers/rmsprop.py +103 -0
- torchzero-0.3.1/torchzero/modules/optimizers/rprop.py +342 -0
- torchzero-0.3.1/torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero-0.3.1/torchzero/modules/optimizers/soap.py +286 -0
- torchzero-0.3.1/torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero-0.3.1/torchzero/modules/projections/__init__.py +5 -0
- torchzero-0.3.1/torchzero/modules/projections/dct.py +73 -0
- torchzero-0.3.1/torchzero/modules/projections/fft.py +73 -0
- torchzero-0.3.1/torchzero/modules/projections/galore.py +10 -0
- torchzero-0.3.1/torchzero/modules/projections/projection.py +218 -0
- torchzero-0.3.1/torchzero/modules/projections/structural.py +151 -0
- torchzero-0.3.1/torchzero/modules/quasi_newton/__init__.py +7 -0
- torchzero-0.3.1/torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero-0.3.1/torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero-0.3.1/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero-0.3.1/torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero-0.3.1/torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero-0.3.1/torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero-0.3.1/torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero-0.3.1/torchzero/modules/second_order/__init__.py +3 -0
- torchzero-0.3.1/torchzero/modules/second_order/newton.py +142 -0
- torchzero-0.3.1/torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero-0.3.1/torchzero/modules/second_order/nystrom.py +168 -0
- torchzero-0.3.1/torchzero/modules/smoothing/__init__.py +2 -0
- torchzero-0.3.1/torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero-0.1.7/src/torchzero/modules/smoothing/laplacian_smoothing.py → torchzero-0.3.1/torchzero/modules/smoothing/laplacian.py +115 -128
- torchzero-0.3.1/torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero-0.3.1/torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero-0.3.1/torchzero/modules/wrappers/__init__.py +1 -0
- torchzero-0.3.1/torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero-0.3.1/torchzero/optim/__init__.py +2 -0
- torchzero-0.3.1/torchzero/optim/utility/__init__.py +1 -0
- torchzero-0.3.1/torchzero/optim/utility/split.py +45 -0
- {torchzero-0.1.7/src → torchzero-0.3.1}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.1.7/src → torchzero-0.3.1}/torchzero/optim/wrappers/nevergrad.py +2 -28
- {torchzero-0.1.7/src → torchzero-0.3.1}/torchzero/optim/wrappers/nlopt.py +31 -16
- {torchzero-0.1.7/src → torchzero-0.3.1}/torchzero/optim/wrappers/scipy.py +79 -156
- torchzero-0.3.1/torchzero/utils/__init__.py +27 -0
- torchzero-0.3.1/torchzero/utils/compile.py +177 -0
- torchzero-0.3.1/torchzero/utils/derivatives.py +513 -0
- torchzero-0.3.1/torchzero/utils/linalg/__init__.py +5 -0
- torchzero-0.3.1/torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero-0.3.1/torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero-0.3.1/torchzero/utils/linalg/qr.py +71 -0
- torchzero-0.3.1/torchzero/utils/linalg/solve.py +168 -0
- torchzero-0.3.1/torchzero/utils/linalg/svd.py +20 -0
- torchzero-0.3.1/torchzero/utils/numberlist.py +132 -0
- torchzero-0.3.1/torchzero/utils/ops.py +10 -0
- torchzero-0.3.1/torchzero/utils/optimizer.py +284 -0
- torchzero-0.3.1/torchzero/utils/optuna_tools.py +40 -0
- torchzero-0.3.1/torchzero/utils/params.py +149 -0
- torchzero-0.3.1/torchzero/utils/python_tools.py +40 -0
- torchzero-0.3.1/torchzero/utils/tensorlist.py +1081 -0
- {torchzero-0.1.7/src → torchzero-0.3.1}/torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1/torchzero.egg-info/PKG-INFO +379 -0
- torchzero-0.3.1/torchzero.egg-info/SOURCES.txt +131 -0
- torchzero-0.3.1/torchzero.egg-info/top_level.txt +4 -0
- torchzero-0.1.7/PKG-INFO +0 -120
- torchzero-0.1.7/README.md +0 -82
- torchzero-0.1.7/src/torchzero/__init__.py +0 -4
- torchzero-0.1.7/src/torchzero/core/__init__.py +0 -13
- torchzero-0.1.7/src/torchzero/core/module.py +0 -494
- torchzero-0.1.7/src/torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero-0.1.7/src/torchzero/modules/__init__.py +0 -21
- torchzero-0.1.7/src/torchzero/modules/adaptive/__init__.py +0 -4
- torchzero-0.1.7/src/torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero-0.1.7/src/torchzero/modules/experimental/__init__.py +0 -19
- torchzero-0.1.7/src/torchzero/modules/experimental/experimental.py +0 -294
- torchzero-0.1.7/src/torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero-0.1.7/src/torchzero/modules/experimental/subspace.py +0 -259
- torchzero-0.1.7/src/torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero-0.1.7/src/torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero-0.1.7/src/torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero-0.1.7/src/torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero-0.1.7/src/torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero-0.1.7/src/torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero-0.1.7/src/torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero-0.1.7/src/torchzero/modules/line_search/__init__.py +0 -30
- torchzero-0.1.7/src/torchzero/modules/line_search/armijo.py +0 -56
- torchzero-0.1.7/src/torchzero/modules/line_search/base_ls.py +0 -139
- torchzero-0.1.7/src/torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero-0.1.7/src/torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero-0.1.7/src/torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero-0.1.7/src/torchzero/modules/meta/__init__.py +0 -12
- torchzero-0.1.7/src/torchzero/modules/meta/alternate.py +0 -65
- torchzero-0.1.7/src/torchzero/modules/meta/grafting.py +0 -195
- torchzero-0.1.7/src/torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero-0.1.7/src/torchzero/modules/meta/return_overrides.py +0 -46
- torchzero-0.1.7/src/torchzero/modules/misc/__init__.py +0 -10
- torchzero-0.1.7/src/torchzero/modules/misc/accumulate.py +0 -43
- torchzero-0.1.7/src/torchzero/modules/misc/basic.py +0 -115
- torchzero-0.1.7/src/torchzero/modules/misc/lr.py +0 -96
- torchzero-0.1.7/src/torchzero/modules/misc/multistep.py +0 -51
- torchzero-0.1.7/src/torchzero/modules/misc/on_increase.py +0 -53
- torchzero-0.1.7/src/torchzero/modules/momentum/__init__.py +0 -4
- torchzero-0.1.7/src/torchzero/modules/momentum/momentum.py +0 -106
- torchzero-0.1.7/src/torchzero/modules/operations/__init__.py +0 -29
- torchzero-0.1.7/src/torchzero/modules/operations/multi.py +0 -298
- torchzero-0.1.7/src/torchzero/modules/operations/reduction.py +0 -134
- torchzero-0.1.7/src/torchzero/modules/operations/singular.py +0 -113
- torchzero-0.1.7/src/torchzero/modules/optimizers/__init__.py +0 -10
- torchzero-0.1.7/src/torchzero/modules/optimizers/adagrad.py +0 -49
- torchzero-0.1.7/src/torchzero/modules/optimizers/adam.py +0 -118
- torchzero-0.1.7/src/torchzero/modules/optimizers/lion.py +0 -28
- torchzero-0.1.7/src/torchzero/modules/optimizers/rmsprop.py +0 -51
- torchzero-0.1.7/src/torchzero/modules/optimizers/rprop.py +0 -99
- torchzero-0.1.7/src/torchzero/modules/optimizers/sgd.py +0 -54
- torchzero-0.1.7/src/torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero-0.1.7/src/torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero-0.1.7/src/torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero-0.1.7/src/torchzero/modules/quasi_newton/__init__.py +0 -4
- torchzero-0.1.7/src/torchzero/modules/regularization/__init__.py +0 -22
- torchzero-0.1.7/src/torchzero/modules/regularization/dropout.py +0 -34
- torchzero-0.1.7/src/torchzero/modules/regularization/noise.py +0 -77
- torchzero-0.1.7/src/torchzero/modules/regularization/normalization.py +0 -328
- torchzero-0.1.7/src/torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero-0.1.7/src/torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero-0.1.7/src/torchzero/modules/scheduling/__init__.py +0 -2
- torchzero-0.1.7/src/torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero-0.1.7/src/torchzero/modules/scheduling/step_size.py +0 -80
- torchzero-0.1.7/src/torchzero/modules/second_order/__init__.py +0 -4
- torchzero-0.1.7/src/torchzero/modules/second_order/newton.py +0 -165
- torchzero-0.1.7/src/torchzero/modules/smoothing/__init__.py +0 -5
- torchzero-0.1.7/src/torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero-0.1.7/src/torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero-0.1.7/src/torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero-0.1.7/src/torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero-0.1.7/src/torchzero/optim/__init__.py +0 -10
- torchzero-0.1.7/src/torchzero/optim/experimental/__init__.py +0 -20
- torchzero-0.1.7/src/torchzero/optim/experimental/experimental.py +0 -343
- torchzero-0.1.7/src/torchzero/optim/experimental/ray_search.py +0 -83
- torchzero-0.1.7/src/torchzero/optim/first_order/__init__.py +0 -18
- torchzero-0.1.7/src/torchzero/optim/first_order/cautious.py +0 -158
- torchzero-0.1.7/src/torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero-0.1.7/src/torchzero/optim/first_order/optimizers.py +0 -570
- torchzero-0.1.7/src/torchzero/optim/modular.py +0 -132
- torchzero-0.1.7/src/torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero-0.1.7/src/torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero-0.1.7/src/torchzero/optim/second_order/__init__.py +0 -1
- torchzero-0.1.7/src/torchzero/optim/second_order/newton.py +0 -94
- torchzero-0.1.7/src/torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero-0.1.7/src/torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero-0.1.7/src/torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero-0.1.7/src/torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero-0.1.7/src/torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero-0.1.7/src/torchzero/random/__init__.py +0 -1
- torchzero-0.1.7/src/torchzero/random/random.py +0 -46
- torchzero-0.1.7/src/torchzero/tensorlist.py +0 -826
- torchzero-0.1.7/src/torchzero/utils/__init__.py +0 -0
- torchzero-0.1.7/src/torchzero/utils/compile.py +0 -39
- torchzero-0.1.7/src/torchzero/utils/derivatives.py +0 -99
- torchzero-0.1.7/src/torchzero/utils/python_tools.py +0 -25
- torchzero-0.1.7/src/torchzero.egg-info/PKG-INFO +0 -120
- torchzero-0.1.7/src/torchzero.egg-info/SOURCES.txt +0 -110
- torchzero-0.1.7/src/torchzero.egg-info/top_level.txt +0 -1
- torchzero-0.1.7/tests/test_against_reference.py +0 -152
- torchzero-0.1.7/tests/test_modules.py +0 -129
- torchzero-0.1.7/tests/test_tensorlist.py +0 -27
- {torchzero-0.1.7 → torchzero-0.3.1}/setup.cfg +0 -0
- {torchzero-0.1.7/src → torchzero-0.3.1}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.1.7/src → torchzero-0.3.1}/torchzero.egg-info/requires.txt +0 -0
File without changes

torchzero-0.3.1/PKG-INFO
ADDED
@@ -0,0 +1,379 @@

Metadata-Version: 2.4
Name: torchzero
Version: 0.3.1
Summary: Modular optimization library for PyTorch.
Author-email: Ivan Nikishev <nkshv2@gmail.com>
License: MIT License

Copyright (c) 2024 inikishev

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Project-URL: Homepage, https://github.com/inikishev/torchzero
Project-URL: Repository, https://github.com/inikishev/torchzero
Project-URL: Issues, https://github.com/inikishev/torchzero/isses
Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: numpy
Requires-Dist: typing_extensions
Dynamic: license-file

# torchzero

**Modular optimization library for PyTorch**

<!-- [](https://pypi.org/project/torchzero/)
[](https://opensource.org/licenses/MIT)
[](https://github.com/torchzero/torchzero/actions)
[](https://torchzero.readthedocs.io/en/latest/?badge=latest) -->

`torchzero` is a Python library providing a highly modular framework for creating and experimenting with optimization algorithms in PyTorch. It allows users to easily combine and customize various components of optimizers, such as momentum techniques, gradient clipping, line searches and more.

NOTE: torchzero is in active development; the docs are currently in a state of flux and the pip version is extremely outdated.

## Installation

```bash
pip install git+https://github.com/inikishev/torchzero
```

(please don't use the pip version yet, it is very outdated)

**Dependencies:**

* Python >= 3.10
* `torch`
* `numpy`
* `typing_extensions`

## Core Concepts

<!-- ### Modular Design

`torchzero` is built around a few key abstractions:

* **`Module`**: The base class for all components in `torchzero`. Each `Module` implements a `step(vars)` method that processes the optimization variables.
* **`Modular`**: The main optimizer class that chains together a sequence of `Module`s. It orchestrates the flow of data through the modules in the order they are provided.
* **`Transform`**: A special type of `Module` designed for tensor transformations. These are often used for operations like applying momentum or scaling gradients.
* **`Preconditioner`**: A subclass of `Transform`, typically used for preconditioning gradients (e.g., Adam, RMSprop).

### `Vars` Object

The `Vars` object is a data carrier that passes essential information between modules during an optimization step. It typically holds:

* `params`: The model parameters.
* `grad`: Gradients of the parameters.
* `update`: The update to be applied to the parameters.
* `loss`: The current loss value.
* `closure`: A function to re-evaluate the model and loss (used by some line search algorithms and other modules that might need to recompute gradients or loss).

### `TensorList`

`torchzero` uses a custom `TensorList` class for efficient batched operations on lists of tensors. This allows for optimized performance when dealing with multiple parameter groups or complex update rules. -->

## Quick Start / Usage Example

Here's a basic example of how to use `torchzero`:

```python
import torch
from torch import nn
import torchzero as tz

# Define a simple model
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)

# Create an optimizer
# The order of modules matters:
# 1. SOAP: computes the update.
# 2. NormalizeByEMA: stabilizes the update by normalizing it to an exponential moving average of past updates.
# 3. WeightDecay: semi-decoupled, because it is applied after SOAP but before LR.
# 4. LR: scales the computed update by the learning rate (supports LR schedulers).
optimizer = tz.Modular(
    model.parameters(),
    tz.m.SOAP(),
    tz.m.NormalizeByEMA(max_ema_growth=1.1),
    tz.m.WeightDecay(1e-4),
    tz.m.LR(1e-1),
)

# Standard training loop
for epoch in range(100):
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
```

## Overview of Available Modules

`torchzero` provides a rich set of pre-built modules. Here are some key categories and examples:

* **Optimizers (`torchzero/modules/optimizers/`)**: Optimization algorithms.
  * `Adam`.
  * `Shampoo`.
  * `SOAP` (my current recommendation).
  * `Muon`.
  * `SophiaH`.
  * `Adagrad` and `FullMatrixAdagrad`.
  * `Lion`.
  * `RMSprop`.
  * `OrthoGrad`.
  * `Rprop`.

  Additionally, many other optimizers can be easily defined via modules (see the composition sketch after this list):
  * Grams: `[tz.m.Adam(), tz.m.GradSign()]`
  * LaProp: `[tz.m.RMSprop(), tz.m.EMA(0.9)]`
  * Signum: `[tz.m.HeavyBall(), tz.m.Sign()]`
  * Full-matrix version of any diagonal optimizer, like Adam: `tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9))`
  * Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`

* **Clipping (`torchzero/modules/clipping/`)**: Gradient clipping techniques.
  * `ClipNorm`: Clips the gradient L2 norm.
  * `ClipValue`: Clips gradient values element-wise.
  * `Normalize`: Normalizes gradients to unit norm.
  * `Centralize`: Centralizes gradients by subtracting the mean.
  * `ClipNormByEMA`, `NormalizeByEMA`, `ClipValueByEMA`: Clipping/normalization based on an EMA of past values.
  * `ClipNormGrowth`, `ClipValueGrowth`: Limit norm or value growth.
* **Gradient Approximation (`torchzero/modules/grad_approximation/`)**: Methods for approximating gradients.
  * `FDM`: Finite-difference method.
  * `RandomizedFDM` (`MeZO`, `SPSA`, `RDSA`, `Gaussian smoothing`): Randomized finite-difference methods (also subspaces).
  * `ForwardGradient`: Randomized gradient approximation via forward-mode automatic differentiation.
* **Line Search (`torchzero/modules/line_search/`)**: Techniques for finding optimal step sizes.
  * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches.
  * `StrongWolfe`: Cubic-interpolation line search satisfying the strong Wolfe conditions.
  * `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.
  * `TrustRegion`: First-order trust region method.
* **Learning Rate (`torchzero/modules/lr/`)**: Learning rate control.
  * `LR`: Applies a fixed learning rate.
  * `PolyakStepSize`: Polyak's step size method.
  * `Warmup`: Learning rate warmup.
* **Momentum (`torchzero/modules/momentum/`)**: Momentum-based update modifications.
  * `NAG`: Nesterov Accelerated Gradient.
  * `HeavyBall`: Classic momentum (Polyak's momentum).
  * `EMA`: Exponential moving average.
  * `Averaging` (`Medianveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
  * `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.
  * `MatrixMomentum`, `AdaptiveMatrixMomentum`: Second-order momentum.
  <!-- * `CoordinateMomentum`: Momentum via random coordinates. -->
* **Projections (`torchzero/modules/projections/`)**: Gradient projection techniques.
  * `FFTProjection`, `DCTProjection`: Use any update rule in the Fourier or DCT domain.
  * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods.
  <!-- * *(Note: DCT and Galore were commented out in the `__init__.py` I read, might be experimental or moved).* -->
* **Quasi-Newton (`torchzero/modules/quasi_newton/`)**: Approximate second-order optimization methods.
  * `LBFGS`: Limited-memory BFGS.
  * `LSR1`: Limited-memory SR1.
  * `OnlineLBFGS`: Online L-BFGS.
  <!-- * `ModularLBFGS`: A modular L-BFGS implementation (from experimental). -->
  * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix quasi-Newton update formulas.
  * Conjugate gradient methods: `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`.
* **Second Order (`torchzero/modules/second_order/`)**: Second-order methods.
  * `Newton`: Classic Newton's method.
  * `NewtonCG`: Matrix-free Newton's method with a conjugate gradient solver.
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
  * `NystromPCG`: NewtonCG with Nyström preconditioning.
* **Smoothing (`torchzero/modules/smoothing/`)**: Techniques for smoothing the loss landscape or gradients.
  * `LaplacianSmoothing`: Laplacian smoothing for gradients.
  * `GaussianHomotopy`: Smoothing via randomized Gaussian homotopy.
* **Weight Decay (`torchzero/modules/weight_decay/`)**: Weight decay implementations.
  * `WeightDecay`: Standard L2 or L1 weight decay.
  <!-- * `DirectWeightDecay`: Applies weight decay directly to weights.
  * `decay_weights_`: Functional form for decaying weights. -->
* **Ops (`torchzero/modules/ops/`)**: Various tensor operations and utilities.
  * `GradientAccumulation`: Easy way to add gradient accumulation.
  * `Unary*` (e.g., `Abs`, `Sqrt`, `Sign`): Unary operations.
  * `Binary*` (e.g., `Add`, `Mul`, `Graft`): Binary operations.
  * `Multi*` (e.g., `ClipModules`, `LerpModules`): Operations on multiple module outputs.
  * `Reduce*` (e.g., `Mean`, `Sum`, `WeightedMean`): Reduction operations on multiple module outputs.

* **Wrappers (`torchzero/modules/wrappers/`)**.
  * `Wrap`: Wraps any PyTorch optimizer, allowing it to be used as a module.

<!-- * **Experimental (`torchzero/modules/experimental/`)**: Experimental modules.
  * `GradMin`: Attempts to minimize gradient norm.
  * `ReduceOutwardLR`: Reduces the learning rate for parameters with outward-pointing gradients.
  * `RandomSubspacePreconditioning`, `HistorySubspacePreconditioning`: Preconditioning techniques using random or historical subspaces. -->

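For instance, the "cautious version of any optimizer" recipe above drops straight into `tz.Modular` together with the clipping and learning-rate modules. A minimal composition sketch, reusing `model` from the Quick Start example; the clipping threshold and learning rate are arbitrary example values, and default constructor arguments are assumed for the other modules:

```python
import torchzero as tz

# Cautious SOAP with gradient norm clipping, composed from the modules listed above.
optimizer = tz.Modular(
    model.parameters(),
    tz.m.ClipNorm(1.0),   # clip the gradient L2 norm before the update rule
    tz.m.SOAP(),          # compute the SOAP update
    tz.m.Cautious(),      # apply cautioning to the update
    tz.m.LR(1e-2),        # scale by the learning rate
)
```
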
## Advanced Usage

### Closure

Certain modules, particularly line searches and gradient approximations, require a closure, similar to L-BFGS in PyTorch. In TorchZero, the closure accepts an additional `backward` argument; refer to the example below:

```python
# basic training loop
for inputs, targets in dataloader:

    def closure(backward=True):  # make sure it is True by default
        preds = model(inputs)
        loss = criterion(preds, targets)

        if backward:
            optimizer.zero_grad()
            loss.backward()

        return loss

    loss = optimizer.step(closure)
```

The closure above also works with all PyTorch optimizers and most custom ones, so there is no need to rewrite the training loop.

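For example, `torch.optim.LBFGS` calls `closure()` with no arguments, so the closure defined above (with `backward=True` as its default) can be reused unchanged; `model` and `closure` here refer to the example above and the learning rate is an arbitrary example value:

```python
import torch

# The same closure drives a built-in closure-based optimizer without changes.
lbfgs = torch.optim.LBFGS(model.parameters(), lr=1.0)
loss = lbfgs.step(closure)
```
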
Non-batched example (Rosenbrock):

```py
import torch
import torchzero as tz

def rosen(x, y):
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

W = torch.tensor([-1.1, 2.5], requires_grad=True)

def closure(backward=True):
    loss = rosen(*W)
    if backward:
        W.grad = None  # same as opt.zero_grad()
        loss.backward()
    return loss

opt = tz.Modular([W], tz.m.NewtonCG(), tz.m.StrongWolfe())
for step in range(20):
    loss = opt.step(closure)
    print(f'{step} - {loss}')
```

### Low level modules

TorchZero provides a lot of low-level modules that can be used to recreate update rules, or to combine existing update rules in new ways. Here are some equivalent ways to build Adam, from the highest-level to the most granular (`amsgrad` below stands for a boolean flag):

```python
tz.m.Adam()
```

```python
tz.m.RMSprop(0.999, debiased=True, init='zeros', inner=tz.m.EMA(0.9))
```

```python
tz.m.DivModules(
    tz.m.EMA(0.9, debiased=True),
    [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)],
)
```

```python
tz.m.DivModules(
    [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
    [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Add(1e-8)],
)
```

```python
tz.m.DivModules(
    [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
    [
        tz.m.Pow(2),
        tz.m.EMA(0.999),
        tz.m.AccumulateMaximum() if amsgrad else tz.m.Identity(),
        tz.m.Sqrt(),
        tz.m.Debias2(beta=0.999),
        tz.m.Add(1e-8),
    ],
)
```

There are practically no rules to the ordering of the modules - anything will work, even a line search after a line search, or nested Gaussian homotopy.

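As an illustration of that freedom, the chain below stacks two line searches back to back after L-BFGS, reusing `W` and `closure` from the Rosenbrock example above. This is only a sketch of the claim, with default constructors assumed, not a recommended configuration:

```python
# L-BFGS direction, then a backtracking line search, then a strong Wolfe line search
# applied on top of it - an unusual but valid module ordering.
opt = tz.Modular([W], tz.m.LBFGS(), tz.m.Backtracking(), tz.m.StrongWolfe())
for step in range(20):
    loss = opt.step(closure)
```
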
### Quick guide to implementing new modules

Modules are quite similar to `torch.optim.Optimizer`; the main difference is that everything is stored in the `Vars` object, not in the module itself. Also, both per-parameter settings and state are stored in per-parameter dictionaries. Feel free to modify the example below.

```python
import torch
from torchzero.core import Module, Vars

class HeavyBall(Module):
    def __init__(self, momentum: float = 0.9, dampening: float = 0):
        defaults = dict(momentum=momentum, dampening=dampening)
        super().__init__(defaults)

    def step(self, vars: Vars):
        # a module takes a Vars object, modifies it or creates a new one, and returns it.
        # Vars has a bunch of attributes, including parameters, gradients, update, closure and loss.
        # For now we are only interested in the update, and we will apply the heavy ball rule to it.

        params = vars.params
        update = vars.get_update()  # list of tensors

        exp_avg_list = []
        for p, u in zip(params, update):
            state = self.state[p]
            settings = self.settings[p]
            momentum = settings['momentum']
            dampening = settings['dampening']

            if 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

            buf = state['momentum_buffer']
            u *= 1 - dampening

            buf.mul_(momentum).add_(u)

            # clone because further modules might modify exp_avg in-place
            # and it is part of self.state
            exp_avg_list.append(buf.clone())

        # set the new update on vars
        vars.update = exp_avg_list
        return vars
```

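A module defined this way plugs into `tz.Modular` like any built-in one; for example, reusing `model` from the Quick Start example (the learning rate is an arbitrary example value):

```python
import torchzero as tz

# Chain the custom HeavyBall module with a fixed learning rate.
optimizer = tz.Modular(model.parameters(), HeavyBall(momentum=0.9), tz.m.LR(1e-2))
```
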
There are also some specialized base modules:

* `GradApproximator` for gradient approximations
* `LineSearch` for line searches
* `Preconditioner` for gradient preconditioners
* `QuasiNewtonH` for full-matrix quasi-Newton methods that update a Hessian inverse approximation (because they are all very similar)
* `ConguateGradientBase` for conjugate gradient methods; basically the only difference between them is how beta is calculated.

## License

This project is licensed under the MIT License.

## Project Links

TODO (there are docs, but from a very old version)
<!-- * **Homepage**: `https://torchzero.github.io/torchzero/` (Placeholder - update if available)
* **Repository**: `https://github.com/torchzero/torchzero` (Assuming this is the correct path) -->

## Other stuff

There are also wrappers providing a `torch.optim.Optimizer` interface for `scipy.optimize`, NLOpt and Nevergrad.

They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure the closure has a `backward` argument as described in **Advanced Usage**.
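As a rough sketch of how a wrapper is used, reusing the Rosenbrock closure from **Advanced Usage**. Only the closure-based `step` interface is described above; any further constructor options of `ScipyMinimize` (such as method selection) are not, so none are passed here:

```python
import torch
from torchzero.optim.wrappers.scipy import ScipyMinimize

def rosen(x, y):
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

W = torch.tensor([-1.1, 2.5], requires_grad=True)

def closure(backward=True):
    loss = rosen(*W)
    if backward:
        W.grad = None
        loss.backward()
    return loss

# Check the wrapper's signature for method selection and other options.
opt = ScipyMinimize([W])
loss = opt.step(closure)
```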