torchzero 0.3.4__tar.gz → 0.3.6__tar.gz
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- {torchzero-0.3.4 → torchzero-0.3.6}/PKG-INFO +104 -104
- {torchzero-0.3.4 → torchzero-0.3.6}/README.md +103 -103
- {torchzero-0.3.4 → torchzero-0.3.6}/pyproject.toml +2 -2
- {torchzero-0.3.4 → torchzero-0.3.6}/tests/test_tensorlist.py +17 -17
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero.egg-info/PKG-INFO +104 -104
- {torchzero-0.3.4 → torchzero-0.3.6}/LICENSE +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/docs/source/conf.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/setup.cfg +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/tests/test_identical.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/tests/test_module.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/tests/test_opts.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/tests/test_vars.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/core/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/core/module.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/core/preconditioner.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/core/transform.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/clipping/clipping.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/clipping/ema_clipping.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/clipping/growth_clipping.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/absoap.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/adadam.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/adamY.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/adasoap.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/algebraic_newton.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/curveball.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/dsoap.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/gradmin.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/spectral.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/subspace_preconditioners.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/experimental/tropical_newton.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/functional.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/grad_approximation/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/grad_approximation/fdm.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/grad_approximation/rfdm.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/line_search/backtracking.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/line_search/line_search.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/line_search/scipy.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/line_search/strong_wolfe.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/line_search/trust_region.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/lr/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/lr/lr.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/lr/step_size.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/momentum/averaging.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/momentum/cautious.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/momentum/ema.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/momentum/experimental.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/momentum/matrix_momentum.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/momentum/momentum.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/accumulate.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/binary.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/debug.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/misc.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/multi.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/reduce.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/split.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/switch.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/unary.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/ops/utility.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/adagrad.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/adam.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/lion.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/muon.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/orthograd.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/rmsprop.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/rprop.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/shampoo.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/soap.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/optimizers/sophia_h.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/projections/dct.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/projections/fft.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/projections/galore.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/projections/projection.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/projections/structural.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/quasi_newton/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/quasi_newton/cg.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/quasi_newton/lbfgs.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/quasi_newton/lsr1.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/quasi_newton/olbfgs.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/quasi_newton/quasi_newton.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/second_order/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/second_order/newton.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/second_order/newton_cg.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/second_order/nystrom.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/smoothing/gaussian.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/smoothing/laplacian.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/weight_decay/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/weight_decay/weight_decay.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/optim/wrappers/nevergrad.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/optim/wrappers/scipy.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/compile.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/derivatives.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/linalg/__init__.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/linalg/matrix_funcs.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/linalg/orthogonalize.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/linalg/qr.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/linalg/solve.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/linalg/svd.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/ops.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/optimizer.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/params.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/python_tools.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/tensorlist.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero.egg-info/SOURCES.txt +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.3.4 → torchzero-0.3.6}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.3.4 → torchzero-0.3.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.4
+Version: 0.3.6
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 License: MIT License
@@ -37,26 +37,27 @@ Requires-Dist: numpy
 Requires-Dist: typing_extensions
 Dynamic: license-file
 
+
+
 # torchzero
 
 **Modular optimization library for PyTorch**
 
-
-[](https://opensource.org/licenses/MIT)
-[](https://github.com/torchzero/torchzero/actions)
-[](https://torchzero.readthedocs.io/en/latest/?badge=latest) -->
-
-`torchzero` is a Python library providing a highly modular framework for creating and experimenting with optimization algorithms in PyTorch. It allows users to easily combine and customize various components of optimizers, such as momentum techniques, gradient clipping, line searches and more.
+`torchzero` is a PyTorch library providing a highly modular framework for creating and experimenting with a huge number of various optimization algorithms - various momentum techniques, gradient clipping, gradient approximations, line searches, quasi newton methods and more. All algorithms are implemented as modules that can be chained together freely.
 
-NOTE: torchzero is in active development, currently docs are in a state of flux
+NOTE: torchzero is in active development, currently docs are in a state of flux.
 
 ## Installation
 
 ```bash
-pip install
+pip install torchzero
 ```
 
-
+pip version is always the latest one. Or install from this repo
+
+```bash
+pip install git+https://github.com/inikishev/torchzero
+```
 
 **Dependencies:**
 
@@ -65,34 +66,9 @@ pip install git+https://github.com/inikishev/torchzero
 * `numpy`
 * `typing_extensions`
 
-## Core Concepts
-
-<!-- ### Modular Design
-
-`torchzero` is built around a few key abstractions:
-
-* **`Module`**: The base class for all components in `torchzero`. Each `Module` implements a `step(vars)` method that processes the optimization variables.
-* **`Modular`**: The main optimizer class that chains together a sequence of `Module`s. It orchestrates the flow of data through the modules in the order they are provided.
-* **`Transform`**: A special type of `Module` designed for tensor transformations. These are often used for operations like applying momentum or scaling gradients.
-* **`Preconditioner`**: A subclass of `Transform`, typically used for preconditioning gradients (e.g., Adam, RMSprop).
-
-### `Vars` Object
-
-The `Vars` object is a data carrier that passes essential information between modules during an optimization step. It typically holds:
-
-* `params`: The model parameters.
-* `grad`: Gradients of the parameters.
-* `update`: The update to be applied to the parameters.
-* `loss`: The current loss value.
-* `closure`: A function to re-evaluate the model and loss (used by some line search algorithms and other modules that might need to recompute gradients or loss).
-
-### `TensorList`
-
-`torchzero` uses a custom `TensorList` class for efficient batched operations on lists of tensors. This allows for optimized performance when dealing with multiple parameter groups or complex update rules. -->
-
 ## Quick Start / Usage Example
 
-
+Basic example:
 
 ```python
 import torch
@@ -107,13 +83,16 @@ targets = torch.randn(5, 1)
 
 # Create an optimizer
 # The order of modules matters:
-# 1.
-# 2.
-# 3.
-#
+# 1. ClipValue: clips gradients to (-10, 10) range.
+# 2. Adam: applies Adam update rule to clipped gradients.
+# 3. NormalizeByEMA: stabilizes the update by normalizing it to an exponential
+# moving average of past updates.
+# 4. WeightDecay - decoupled weight decay (can also move after LR to fully decouple)
+# 5. LR: Scales the computed update by the learning rate (supports LR schedulers).
 optimizer = tz.Modular(
     model.parameters(),
-    tz.m.
+    tz.m.ClipValue(10),
+    tz.m.Adam(),
     tz.m.NormalizeByEMA(max_ema_growth=1.1),
     tz.m.WeightDecay(1e-4),
     tz.m.LR(1e-1),
@@ -131,9 +110,9 @@ for epoch in range(100):
 
 ## Overview of Available Modules
 
-`torchzero` provides a
+`torchzero` provides a huge number of various modules:
 
-* **Optimizers
+* **Optimizers**: Optimization algorithms.
   * `Adam`.
   * `Shampoo`.
   * `SOAP` (my current recommendation).
@@ -152,77 +131,74 @@ for epoch in range(100):
   * Full matrix version of any diagonal optimizer, like Adam: `tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9))`
   * Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`
 
-* **
+* **Momentum**:
+  * `NAG`: Nesterov Accelerated Gradient.
+  * `HeavyBall`: Classic momentum (Polyak's momentum).
+  * `EMA`: Exponential moving average.
+  * `Averaging` (`Medianveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
+  * `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.
+  * `MatrixMomentum`, `AdaptiveMatrixMomentum`: Second order momentum.
+
+* **Stabilization**: Gradient stabilization techniques.
   * `ClipNorm`: Clips gradient L2 norm.
   * `ClipValue`: Clips gradient values element-wise.
   * `Normalize`: Normalizes gradients to unit norm.
   * `Centralize`: Centralizes gradients by subtracting the mean.
   * `ClipNormByEMA`, `NormalizeByEMA`, `ClipValueByEMA`: Clipping/Normalization based on EMA of past values.
   * `ClipNormGrowth`, `ClipValueGrowth`: Limits norm or value growth.
-
-
-  * `
+
+* **Gradient approximations**: Methods for approximating gradients.
+  * `FDM`: Finite difference method.
+  * `RandomizedFDM` (`MeZO`, `SPSA`, `RDSA`, `Gaussian smoothing`): Randomized finite difference methods (also subspaces).
   * `ForwardGradient`: Randomized gradient approximation via forward mode automatic differentiation.
-
-
+
+* **Second order**: Second order methods.
+  * `Newton`: Classic Newton's method.
+  * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
+  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
+  * `NystromPCG`: NewtonCG with Nyström preconditioning (my current recommendation).
+
+* **Quasi-Newton**: Approximate second-order optimization methods.
+  * `LBFGS`: Limited-memory BFGS.
+  * `LSR1`: Limited-memory SR1.
+  * `OnlineLBFGS`: Online LBFGS.
+  * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix quasi-newton methods.
+  * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+
+* **Line Search**:
+  * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
   * `StrongWolfe`: Cubic interpolation line search satisfying strong Wolfe conditions.
   * `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.
   * `TrustRegion`: First order trust region method.
-
-
+
+* **Learning Rate**:
+  * `LR`: Controls learning rate and adds support for LR schedulers.
   * `PolyakStepSize`: Polyak's method.
   * `Warmup`: Learning rate warmup.
-
-
-  * `
-  * `
-
-
-  * `
-<!-- * `CoordinateMomentum`: Momentum via random coordinates. -->
-* **Projections (`torchzero/modules/projections/`)**: Gradient projection techniques.
-  * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain.
-  * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods.
-<!-- * *(Note: DCT and Galore were commented out in the `__init__.py` I read, might be experimental or moved).* -->
-* **Quasi-Newton (`torchzero/modules/quasi_newton/`)**: Approximate second-order optimization methods.
-  * `LBFGS`: Limited-memory BFGS.
-  * `LSR1`: Limited-memory SR1.
-  * `OnlineLBFGS`: Online LBFGS.
-<!-- * `ModularLBFGS`: A modular L-BFGS implementation (from experimental). -->
-  * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix Quasi-Newton update formulas.
-  * Conjugate Gradient methods: `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`.
-* **Second Order (`torchzero/modules/second_order/`)**: Second order methods.
-  * `Newton`: Classic Newton's method.
-  * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
-  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-  * `NystromPCG`: NewtonCG with Nyström preconditioning.
-* **Smoothing (`torchzero/modules/smoothing/`)**: Techniques for smoothing the loss landscape or gradients.
-  * `LaplacianSmoothing`: Laplacian smoothing for gradients.
+
+* **Projections**: This can implement things like GaLore but I haven't done that yet.
+  * `FFTProjection`, `DCTProjection`: Use any update rule in Fourier or DCT domain (doesn't seem to help though).
+  * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.).
+
+* **Smoothing**: Smoothing-based optimization methods.
+  * `LaplacianSmoothing`: Laplacian smoothing for gradients (implements Laplacian Smooth GD).
   * `GaussianHomotopy`: Smoothing via randomized Gaussian homotopy.
-
+
+* **Weight Decay**:.
   * `WeightDecay`: Standard L2 or L1 weight decay.
-
-
-
-
-  * `Unary*` (e.g., `Abs`, `Sqrt`, `Sign`): Unary operations.
-  * `Binary*` (e.g., `Add`, `Mul`, `Graft`): Binary operations.
-  * `Multi*` (e.g., `ClipModules`, `LerpModules`): Operations on multiple module outputs.
-  * `Reduce*` (e.g., `Mean`, `Sum`, `WeightedMean`): Reduction operations on multiple module outputs.
-
-* **Wrappers (`torchzero/modules/wrappers/`)**.
+
+* **Ops**: This has low level operations, also stuff like grafting and gradient accumulation.
+
+* **Wrappers**.
   * `Wrap`: Wraps any PyTorch optimizer, allowing to use it as a module.
 
-
-  * `GradMin`: Attempts to minimize gradient norm.
-  * `ReduceOutwardLR`: Reduces learning rate for parameters with outward pointing gradients.
-  * `RandomSubspacePreconditioning`, `HistorySubspacePreconditioning`: Preconditioning techniques using random or historical subspaces. -->
+* **Experimental**: various horrible atrocities
 
 ## Advanced Usage
 
 ### Closure
 
-Certain modules, particularly line searches and gradient approximations require a closure, similar to L-BFGS in PyTorch.
+Certain modules, particularly line searches and gradient approximations require a closure, similar to L-BFGS in PyTorch. Also some modules require closure to accept an additional `backward` argument, refer to example below:
 
 ```python
 # basic training loop
@@ -232,7 +208,7 @@ for inputs, targets in dataloader:
         preds = model(inputs)
         loss = criterion(preds, targets)
 
-        if backward:
+        if backward: # gradient approximations always call with backward=False.
             optimizer.zero_grad()
             loss.backward()
 
@@ -241,7 +217,7 @@ for inputs, targets in dataloader:
     loss = optimizer.step(closure)
 ```
 
-
+The code above will also work with any other optimizer because all PyTorch optimizers and most custom ones support closure, so there is no need to rewrite training loop.
 
 Non-batched example (rosenbrock):
 
@@ -266,9 +242,31 @@ for step in range(20):
     print(f'{step} - {loss}')
 ```
 
+### Module combinations
+
+There are practically no rules to the ordering of the modules - anything will work. For example any method can be made zeroth order by putting it after some gradient approximation module such as GaussianSmoothing:
+
+```python
+opt = tz.Modular(
+    bench.parameters(),
+    tz.m.GaussianSmoothing(h=0.01, n_samples=10),
+    tz.m.NewtonCG(hvp_method='forward'),
+    tz.m.AdaptiveBacktracking(),
+)
+```
+
+GaussianSmoothing actually creates a new **closure** which approximates the gradient. To NewtonCG this closure is just like
+any other closure, so it works seamlessly.
+
+Any module can be projected (this is how it will work once I implement GaLore, but I haven't done that yet):
+
+```python
+tz.m.GaLore([tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(1e-2)])
+```
+
 ### Low level modules
 
-
+torchzero provides a lot of low-level modules that can be used to recreate update rules, or combine existing update rules
 in new ways. Here are some equivalent ways to make Adam in order of their involvement:
 
 ```python
@@ -276,20 +274,21 @@ tz.m.Adam()
 ```
 
 ```python
+# Adam is debiased RMSprop applied to EMA
 tz.m.RMSprop(0.999, debiased=True, init='zeros', inner=tz.m.EMA(0.9))
 ```
 
 ```python
 tz.m.DivModules(
     tz.m.EMA(0.9, debiased=True),
-    [tz.m.SqrtEMASquared(0.999, debiased=True
+    [tz.m.SqrtEMASquared(0.999, debiased=True), tz.m.Add(1e-8)]
 )
 ```
 
 ```python
 tz.m.DivModules(
     [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
-    [tz.m.EMASquared(0.999
+    [tz.m.EMASquared(0.999), tz.m.Sqrt(), tz.m.Add(1e-8)]
 )
 ```
 
@@ -306,8 +305,6 @@ tz.m.DivModules(
 )
 ```
 
-There are practically no rules to the ordering of the modules - anything will work, even line search after line search or nested gaussian homotopy.
-
 ### Quick guide to implementing new modules
 
 Modules are quite similar to torch.optim.Optimizer, the main difference is that everything is stored in the Vars object,
@@ -354,14 +351,17 @@ class HeavyBall(Module):
         return vars
 ```
 
-There are a some specialized base modules.
+There are a some specialized base modules that make it much easier to implement some specific things.
 
 * `GradApproximator` for gradient approximations
 * `LineSearch` for line searches
-* `Preconditioner` for
+* `Preconditioner` for preconditioners
+* `Projection` for projections like GaLore or into fourier domain.
 * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
 * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
 
+The documentation on how to actually use them is to write itself in the near future.
+
 ## License
 
 This project is licensed under the MIT License
@@ -369,11 +369,11 @@ This project is licensed under the MIT License
 ## Project Links
 
 TODO (there are docs but from very old version)
-<!-- * **Homepage**: `https://torchzero.github.io/torchzero/` (Placeholder - update if available)
-* **Repository**: `https://github.com/torchzero/torchzero` (Assuming this is the correct path) -->
 
 ## Other stuff
 
 There are also wrappers providing `torch.optim.Optimizer` interface for for `scipy.optimize`, NLOpt and Nevergrad.
 
 They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.
+
+Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.