torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -1,404 +0,0 @@ torchzero-0.3.11.dist-info/METADATA (removed)

Metadata-Version: 2.4
Name: torchzero
Version: 0.3.11
Summary: Modular optimization library for PyTorch.
Author-email: Ivan Nikishev <nkshv2@gmail.com>
License: MIT License

Copyright (c) 2024 inikishev

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Project-URL: Homepage, https://github.com/inikishev/torchzero
Project-URL: Repository, https://github.com/inikishev/torchzero
Project-URL: Issues, https://github.com/inikishev/torchzero/isses
Keywords: optimization,optimizers,torch,neural networks,zeroth order,second order
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: numpy
Requires-Dist: typing_extensions
Dynamic: license-file



# torchzero

**Modular optimization library for PyTorch**

`torchzero` is a PyTorch library providing a highly modular framework for creating and experimenting with a wide range of optimization algorithms - momentum techniques, gradient clipping, gradient approximations, line searches, quasi-Newton methods, and more. All algorithms are implemented as modules that can be chained together freely.

## Installation

```bash
pip install torchzero
```

The pip version is always the latest release. Alternatively, install from this repo:

```bash
pip install git+https://github.com/inikishev/torchzero
```

**Dependencies:**

* Python >= 3.10
* `torch`
* `numpy`
* `typing_extensions`

## Quick Start / Usage Example

Basic example:

```python
import torch
from torch import nn
import torchzero as tz

# Define a simple model
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)

# Create an optimizer
# The order of modules matters:
# 1. ClipValue: clips gradients to the (-10, 10) range.
# 2. Adam: applies the Adam update rule to the clipped gradients.
# 3. NormalizeByEMA: stabilizes the update by normalizing it to an exponential
#    moving average of past updates.
# 4. WeightDecay: decoupled weight decay (can also be moved after LR to fully decouple it).
# 5. LR: scales the computed update by the learning rate (supports LR schedulers).
optimizer = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(10),
    tz.m.Adam(),
    tz.m.NormalizeByEMA(max_ema_growth=1.1),
    tz.m.WeightDecay(1e-4),
    tz.m.LR(1e-1),
)

# Standard training loop
for epoch in range(100):
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
```

## Overview of Available Modules

`torchzero` provides a large number of modules:

* **Optimizers**: Optimization algorithms.
  * `Adam`, `Adan`, `Adagrad`, `ESGD`, `FullMatrixAdagrad`, `LMAdagrad`, `AdaHessian`, `AdaptiveHeavyBall`, `OrthoGrad`, `Lion`, `MARS`, `MatrixMomentum`, `AdaptiveMatrixMomentum`, `Muon`, `RMSprop`, `Rprop`, `SAM`, `ASAM`, `MSAM`, `Shampoo`, `SOAP`, `SophiaH`.

  Additionally, many other optimizers can easily be defined via modules (see the sketch after this list):
  * Grams: `[tz.m.Adam(), tz.m.GradSign()]`
  * LaProp: `[tz.m.RMSprop(), tz.m.EMA(0.9)]`
  * Signum: `[tz.m.HeavyBall(), tz.m.Sign()]`
  * Efficient full-matrix version of any diagonal optimizer, like Adam: `[tz.m.LMAdagrad(beta=0.999, inner=tz.m.EMA(0.9)), tz.m.Debias(0.9, 0.999)]`
  * Cautious version of any optimizer, like SOAP: `[tz.m.SOAP(), tz.m.Cautious()]`

* **Momentum**:
  * `HeavyBall`: Classic momentum (Polyak's momentum).
  * `NAG`: Nesterov Accelerated Gradient.
  * `EMA`: Exponential moving average.
  * `Averaging` (`MedianAveraging`, `WeightedAveraging`): Simple, median, or weighted averaging of updates.
  * `Cautious`, `ScaleByGradCosineSimilarity`: Momentum cautioning.

* **Stabilization**: Gradient stabilization techniques.
  * `ClipNorm`: Clips the gradient L2 norm.
  * `ClipValue`: Clips gradient values element-wise.
  * `Normalize`: Normalizes gradients to unit norm.
  * `Centralize`: Centralizes gradients by subtracting the mean.
  * `ClipNormByEMA`, `NormalizeByEMA`, `ClipValueByEMA`: Clipping/normalization based on an EMA of past values.
  * `ClipNormGrowth`, `ClipValueGrowth`: Limits norm or value growth.

* **Gradient approximations**: Methods for approximating gradients.
  * `FDM`: Finite-difference method.
  * `RandomizedFDM` (`MeZO`, `SPSA`, `RDSA`, `Gaussian smoothing`): Randomized finite-difference methods (also subspaces).
  * `ForwardGradient`: Randomized gradient approximation via forward-mode automatic differentiation.

* **Second order**: Second-order methods.
  * `Newton`: Classic Newton's method.
  * `InverseFreeNewton`: Inverse-free version of Newton's method.
  * `NewtonCG`: Matrix-free Newton's method with conjugate gradient or minimal residual solvers.
  * `TruncatedNewtonCG`: Steihaug-Toint trust-region Newton-CG via a truncated CG solver.
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
  * `NystromPCG`: NewtonCG with Nyström preconditioning.
  * `HigherOrderNewton`: Higher-order Newton's method with trust region.

* **Quasi-Newton**: Approximate second-order optimization methods.
  * `LBFGS`: Limited-memory BFGS.
  * `LSR1`: Limited-memory SR1.
  * `OnlineLBFGS`: Online LBFGS.
  * `BFGS`, `DFP`, `ICUM`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `NewSSM`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`, `ShorR`: Full-matrix quasi-Newton methods.
  * `DiagonalBFGS`, `DiagonalSR1`, `DiagonalQuasiCauchi`, `DiagonalWeightedQuasiCauchi`, `DNRTR`, `NewDQN`: Diagonal quasi-Newton methods.
  * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.

* **Trust Region**: Trust region can work with the exact Hessian or any of the quasi-Newton methods (L-BFGS support is WIP).
  * `TrustCG`: Trust region using a Steihaug-Toint truncated CG solver.
  * `CubicRegularization`: Cubic regularization; works better with the exact Hessian.

* **Line Search**:
  * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
  * `StrongWolfe`: Cubic-interpolation line search satisfying the strong Wolfe conditions.
  * `ScipyMinimizeScalar`: Wrapper for SciPy's scalar minimization for line search.

* **Learning Rate**:
  * `LR`: Controls the learning rate and adds support for LR schedulers.
  * `PolyakStepSize`: Polyak's subgradient method.
  * `BarzilaiBorwein`: Barzilai-Borwein step size.
  * `Warmup`, `WarmupNormCLip`: Learning rate warmup.

* **Projections**: This can implement things like GaLore, but I haven't done that yet. This is WIP.
  <!-- * `FFTProjection`, `DCTProjection`: Use any update rule in the Fourier or DCT domain (doesn't seem to help though).
  * `VectorProjection`, `TensorizeProjection`, `BlockPartition`, `TensorNormsProjection`: Structural projection methods (for block BFGS etc.). -->
  * `To`: casts everything to any other dtype and device for the modules it wraps, e.g. if you want higher precision.
  * `ViewAsReal`: add this if you have complex parameters.

* **Smoothing**: Smoothing-based optimization methods.
  * `LaplacianSmoothing`: Laplacian smoothing for gradients (implements Laplacian Smoothing GD).
  * `GaussianHomotopy`: Smoothing via randomized Gaussian homotopy.

* **Weight Decay**:
  * `WeightDecay`: Standard L2 or L1 weight decay.

* **Ops**: Low-level operations, plus things like grafting and gradient accumulation.

* **Wrappers**:
  * `Wrap`: Wraps any PyTorch optimizer, allowing it to be used as a module.

* **Experimental**: various horrible atrocities.

A complete list of modules is available in the [documentation](https://torchzero.readthedocs.io/en/latest/autoapi/torchzero/modules/index.html).
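As a concrete example of the combinations listed above, here is a minimal sketch that assembles the cautious-SOAP combination from the list into a full optimizer; the module names come from the list, and the hyperparameters are arbitrary placeholders:

```python
import torch
from torch import nn
import torchzero as tz

model = nn.Linear(10, 1)

# Cautious SOAP: the [tz.m.SOAP(), tz.m.Cautious()] combination from the
# list above, with LR scaling the final update as in the quick-start example.
cautious_soap = tz.Modular(
    model.parameters(),
    tz.m.SOAP(),
    tz.m.Cautious(),
    tz.m.LR(1e-3),
)
```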
## Advanced Usage

### Closure

Certain modules, particularly line searches and gradient approximations, require a closure, similar to L-BFGS in PyTorch. Some modules also require the closure to accept an additional `backward` argument; refer to the example below:

```python
# basic training loop
for inputs, targets in dataloader:

    def closure(backward=True): # make sure it is True by default
        preds = model(inputs)
        loss = criterion(preds, targets)

        if backward: # gradient approximations always call with backward=False.
            optimizer.zero_grad()
            loss.backward()

        return loss

    loss = optimizer.step(closure)
```

The code above also works with any other optimizer, because all PyTorch optimizers and most custom ones support closures, so there is no need to rewrite the training loop.
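For instance, the exact same closure can be passed to PyTorch's built-in L-BFGS - a minimal self-contained sketch:

```python
import torch
from torch import nn

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
inputs, targets = torch.randn(5, 10), torch.randn(5, 1)

opt = torch.optim.LBFGS(model.parameters(), line_search_fn="strong_wolfe")

def closure(backward=True):
    preds = model(inputs)
    loss = criterion(preds, targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

# torch.optim.LBFGS calls closure() with no arguments; since backward
# defaults to True, the same closure works unchanged.
loss = opt.step(closure)
```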
Non-batched example (Rosenbrock):

```py
import torch
import torchzero as tz

def rosen(x, y):
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

W = torch.tensor([-1.1, 2.5], requires_grad=True)

def closure(backward=True):
    loss = rosen(*W)
    if backward:
        W.grad = None # same as opt.zero_grad()
        loss.backward()
    return loss

opt = tz.Modular([W], tz.m.NewtonCG(), tz.m.StrongWolfe())
for step in range(20):
    loss = opt.step(closure)
    print(f'{step} - {loss}')
```

### Module combinations

There are practically no rules to the ordering of the modules - anything will work. For example, any method can be made zeroth-order by putting it after a gradient approximation module such as GaussianSmoothing:

```python
opt = tz.Modular(
    bench.parameters(),
    tz.m.GaussianSmoothing(h=0.01, n_samples=10),
    tz.m.NewtonCG(hvp_method='forward'),
    tz.m.AdaptiveBacktracking(),
)
```

GaussianSmoothing actually creates a new **closure** which approximates the gradient. To NewtonCG this closure is just like any other closure, so it works seamlessly.

Any module can be projected (this is how it will work once I implement GaLore, but I haven't done that yet):

```python
tz.m.GaLore([tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(1e-2)])
```

### Low-level modules

torchzero provides a lot of low-level modules that can be used to recreate update rules, or to combine existing update rules in new ways. Here are some equivalent ways to make Adam, in increasing order of granularity:

```python
tz.m.Adam()
```

```python
# Adam is debiased RMSprop applied to an EMA of gradients
tz.m.RMSprop(0.999, debiased=True, init='zeros', inner=tz.m.EMA(0.9))
```

```python
tz.m.DivModules(
    tz.m.EMA(0.9, debiased=True),
    [tz.m.SqrtEMASquared(0.999, debiased=True), tz.m.Add(1e-8)]
)
```

```python
tz.m.DivModules(
    [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
    [tz.m.EMASquared(0.999), tz.m.Sqrt(), tz.m.Add(1e-8)]
)
```

```python
tz.m.DivModules(
    [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
    [
        tz.m.Pow(2),
        tz.m.EMA(0.999),
        tz.m.AccumulateMaximum() if amsgrad else tz.m.Identity(),
        tz.m.Sqrt(),
        tz.m.Debias2(beta=0.999),
        tz.m.Add(1e-8),
    ]
)
```
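To sanity-check that such compositions coincide, the built-in module and a composed variant can be run side by side on cloned parameters. A minimal sketch, assuming the equivalence stated above holds exactly (up to floating-point error):

```python
import torch
import torchzero as tz

w1 = torch.randn(5, requires_grad=True)
w2 = w1.detach().clone().requires_grad_(True)

adam = tz.Modular([w1], tz.m.Adam(), tz.m.LR(1e-2))
composed = tz.Modular(
    [w2],
    tz.m.RMSprop(0.999, debiased=True, init='zeros', inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-2),
)

# one step on a simple quadratic loss for each optimizer
for opt, w in ((adam, w1), (composed, w2)):
    opt.zero_grad()
    (w ** 2).sum().backward()
    opt.step()

print(torch.allclose(w1, w2))  # expected: True, up to float error
```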
### Quick guide to implementing new modules

Modules are quite similar to torch.optim.Optimizer; the main difference is that everything is stored in the Var object, not in the module itself. Also, both per-parameter settings and state are stored in per-parameter dictionaries. Feel free to modify the example below.

```python
import torch
from torchzero.core import Module, Var

class HeavyBall(Module):
    def __init__(self, momentum: float = 0.9, dampening: float = 0):
        defaults = dict(momentum=momentum, dampening=dampening)
        super().__init__(defaults)

    def step(self, var: Var):
        # The Var object holds all attributes used for optimization - parameters,
        # gradient, update, etc. A module takes a Var object, modifies it or
        # creates a new one, and returns it. Var has a bunch of attributes,
        # including parameters, gradients, update, closure, and loss.
        # For now we are only interested in the update, and we will apply the
        # heavy ball rule to it.

        params = var.params
        update = var.get_update() # list of tensors

        exp_avg_list = []
        for p, u in zip(params, update):
            state = self.state[p]
            settings = self.settings[p]
            momentum = settings['momentum']
            dampening = settings['dampening']

            if 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

            buf = state['momentum_buffer']
            u *= 1 - dampening

            buf.mul_(momentum).add_(u)

            # clone because further modules might modify exp_avg in-place,
            # and it is part of self.state
            exp_avg_list.append(buf.clone())

        # set the new update on var
        var.update = exp_avg_list
        return var
```
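The custom module then composes with the built-in ones exactly like stock modules do - a minimal usage sketch, reusing the `HeavyBall` class defined above:

```python
from torch import nn
import torchzero as tz

model = nn.Linear(10, 1)

# chain the custom module with LR, just like the built-in modules
opt = tz.Modular(model.parameters(), HeavyBall(momentum=0.9), tz.m.LR(1e-2))
```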
A more in-depth guide will be available in the documentation in the future.

## Other stuff

There are also wrappers providing the `torch.optim.Optimizer` interface for various other libraries. When using those, make sure the closure has a `backward` argument as described in **Advanced Usage**.

---

### Scipy

#### torchzero.optim.wrappers.scipy.ScipyMinimize

A wrapper for `scipy.optimize.minimize` with gradients and Hessians supplied by PyTorch autograd. SciPy provides implementations of the following methods: `'nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'`.
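Since the wrappers expose the standard `torch.optim.Optimizer` interface, usage follows the same closure pattern as above. A minimal sketch - the `method` keyword argument is an assumption about the constructor rather than a documented signature, so check the wrapper before relying on it; the same closure-based pattern applies to the other wrappers below:

```python
import torch
from torchzero.optim.wrappers.scipy import ScipyMinimize

W = torch.tensor([-1.1, 2.5], requires_grad=True)

def closure(backward=True):
    # Rosenbrock function, as in the example further above
    loss = (1 - W[0]) ** 2 + 100 * (W[1] - W[0] ** 2) ** 2
    if backward:
        W.grad = None
        loss.backward()
    return loss

opt = ScipyMinimize([W], method='bfgs')  # 'method' kwarg is assumed
loss = opt.step(closure)
```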
#### torchzero.optim.wrappers.scipy.ScipyDE, ScipyDualAnnealing, ScipySHGO, ScipyDIRECT, ScipyBrute

Equivalent wrappers for the other derivative-free solvers available in `scipy.optimize`.

---

### NLOpt

#### torchzero.optim.wrappers.nlopt.NLOptWrapper

A wrapper for [NLOpt](https://github.com/stevengj/nlopt) with gradients supplied by PyTorch autograd. NLOpt is another popular library with many gradient-based and gradient-free [algorithms](https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/).

---

### Nevergrad

#### torchzero.optim.wrappers.nevergrad.NevergradWrapper

A wrapper for [nevergrad](https://facebookresearch.github.io/nevergrad/), which has a huge library of gradient-free [algorithms](https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers).

---

### fast-cma-es

#### torchzero.optim.wrappers.fcmaes.FcmaesWrapper

A wrapper for [fast-cma-es](https://github.com/dietmarwo/fast-cma-es), which implements various gradient-free algorithms. Notably, it includes [BITEOPT](https://github.com/avaneev/biteopt), which seems to have very good performance in benchmarks.

# License

This project is licensed under the MIT License.

# Project Links

The documentation is available at <https://torchzero.readthedocs.io/en/latest/>
@@ -1,159 +0,0 @@ torchzero-0.3.11.dist-info/RECORD (removed)

docs/source/conf.py,sha256=Kd0Uyu6WnhSHEyTbOEjxoaUg4sAu0AxN19raSARtltE,1883
docs/source/docstring template.py,sha256=lIf4Jdkxd-Vr0vOuL9IOTCMOxw5ENsmZDLXKv1eO9ns,1585
tests/test_identical.py,sha256=PJnQtSO3aHZYMQolHmoB26BEUPD_Gpmdh2_M0tfUfm0,11502
tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
tests/test_opts.py,sha256=pAeyDIT0Q4SXBZqR9W_IUjwAEBcMnYr3zE0N4R0xn8w,42509
tests/test_tensorlist.py,sha256=SwzLKLrs2ppMtm_7UrfTDTlD-ObZd7JQ_FNHbp059tc,72460
tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
tests/test_vars.py,sha256=2BoawNdDAnnNh_vv49_peJMnHvaQjp_sfnca1nosTWY,6766
torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
torchzero/core/__init__.py,sha256=Zib_4is13LFAabp_7VU8QXZpQEEZGzsH94vgRI0HxAg,150
torchzero/core/module.py,sha256=BfU4YMjwLrwcz24XAfL-cZx05cESIimViKUStJKBEHM,32872
torchzero/core/transform.py,sha256=sBgEyQVm141v99lnosusNIMWaReuWKuMyzkJha_WwKg,16440
torchzero/modules/__init__.py,sha256=0Gk6XK32FKxtiW9rh-0Plql2dghHn3Ms1F-Ymn4oVzw,386
torchzero/modules/functional.py,sha256=hmJaxB7U9X9nsT1Z5aPSqsw5HsQfL2ns1YS8AWdul6c,6948
torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
torchzero/modules/clipping/clipping.py,sha256=6d-LPCI4zqlcV9fXK8rtRLiReyt8lMeQhmt1gsqNljs,14897
torchzero/modules/clipping/ema_clipping.py,sha256=PNUTvixvc0wdjtWzja6pEzXbNpyXtGxj_H15umWx4zc,6608
torchzero/modules/clipping/growth_clipping.py,sha256=mHn6BQqfHFWnfVjYG_Dokb3VjfSK2QVGsy2rf0Z-RMY,6657
torchzero/modules/experimental/__init__.py,sha256=qV-VaBnRsLFtv6T6R9Imkd1G81QR4O-9_kDbCAwJXeY,1464
torchzero/modules/experimental/absoap.py,sha256=U3nLAV_vxl6HjJhqi8FlK8K6AMLoiZ-deykEshhnCC0,9916
torchzero/modules/experimental/adadam.py,sha256=PARjM2kRmJ7ifYsI83tADKCuvSZYAoT2vR4Gj2aZ-SA,4103
torchzero/modules/experimental/adamY.py,sha256=Rr9vXjFPWTfIHnnhGQAfVAQnfANNgcrFm_R8vJsU1to,4043
torchzero/modules/experimental/adam_lambertw.py,sha256=FXZiTJKVRbXSu9-_boZGYoCqBlh2035mwsagq75qyeA,5323
torchzero/modules/experimental/adaptive_step_size.py,sha256=OJseQX9sd9F58pMC5JbVNm7PtovMXL4sMwQg3jooVtg,3494
torchzero/modules/experimental/adasoap.py,sha256=vcgWEgDdqmgimt5bGgvznCnxkkathGO0engd1xo7M4s,7491
torchzero/modules/experimental/cosine.py,sha256=0Cc42Wd1sMrjm-YxmpcwCCsGpLv3H83rL-XAtrgZhb4,9155
torchzero/modules/experimental/cubic_adam.py,sha256=wHJKm9bO24Xvtwunz_1Kz7mGi_C-syupixiDaBnYx2Q,2787
torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
torchzero/modules/experimental/dct.py,sha256=Iv8ZxGhTOIm3NHS4zxoFG9K9BEwtrJqsKApctiIjnxg,2463
torchzero/modules/experimental/eigendescent.py,sha256=Pdz7QUbM3pD3DTsTC0nZ0AfOe2pj-WVPPkbnw8lDZ3c,4725
torchzero/modules/experimental/etf.py,sha256=ul167I1qAbYeTmTPG_WFLLlE1MEsNXxVsTWd9s2YC9g,6125
torchzero/modules/experimental/exp_adam.py,sha256=yhR5-NGflbEJrSAe0ps4xgAM-eFI-gAdS6cgZIJDgaI,4100
torchzero/modules/experimental/expanded_lbfgs.py,sha256=M58cCaeLZXGqZwyaeGhi-UAyCsnnJvLAYIZ64r0tQNE,5649
torchzero/modules/experimental/fft.py,sha256=YEUKdAXNX8BCZYXKV5uWWU8aTlGjpFTUSpIEwIG-_fM,3050
torchzero/modules/experimental/gradmin.py,sha256=UixSLdca4ekYHOipEivdXfBAV-uEL9TZm5nCFXVaNco,3684
torchzero/modules/experimental/hnewton.py,sha256=_Gv4O2x0qYBxGtkCuYuzL21VuI5wTn1sTEegk17d6X4,3036
torchzero/modules/experimental/modular_lbfgs.py,sha256=d40yRi6NN2Au7-UQ1akMkET0PWhEFAhGKAYoQBDmqFQ,10671
torchzero/modules/experimental/newton_solver.py,sha256=3dZ7FG-2vGxJKkFF9P2LCs-LI_epcvZbyNtJOtw47pg,3055
torchzero/modules/experimental/newtonnewton.py,sha256=cRL4dKsDAN8tHPyHQkLbTGxkHfemCU6re-n4odV3Ik4,3324
torchzero/modules/experimental/parabolic_search.py,sha256=2GgE4cq5QkJYZprADIplQfbPWRJRGFmToYTScJkR0tg,6328
torchzero/modules/experimental/reduce_outward_lr.py,sha256=ui_39wNdf5J2FOQtQFk0WUA8DuicwEp0kepccbq8rI0,1309
torchzero/modules/experimental/structural_projections.py,sha256=lrySQZOq7VhL_VqU7dIJRsypxA16cUliQYkj5-N2B2I,4187
torchzero/modules/experimental/subspace_preconditioners.py,sha256=RdG-RoPF6AiFVphrVlb6egNyYI0e_eHoENUWqKJ4icQ,5170
torchzero/modules/experimental/tensor_adagrad.py,sha256=y29i6BGXwv9lwrTRDzq2YRSngQmfZnreRIeH1NGzpBo,1572
torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
torchzero/modules/grad_approximation/fdm.py,sha256=K_D0fKwspg21Opo2xTG4I34gLDmcaYBp5NUzlaQnjxQ,4490
torchzero/modules/grad_approximation/forward_gradient.py,sha256=AoezoYxXii2gKpIGO7BOZkLb2weYwxrWAKpHL7hrW9Y,4313
torchzero/modules/grad_approximation/grad_approximator.py,sha256=HO-XaNRF3ZwMduBP02V0oabmSRgqmDGPlKkWfDVDPW8,4740
torchzero/modules/grad_approximation/rfdm.py,sha256=omarcZyMgJomJwxQ_b7ulE6eK6aW3JP_Sh-jcX5DhR4,23434
torchzero/modules/higher_order/__init__.py,sha256=W94CY8K1NFxs9TPi415UssKVKz5MV_bH9adax1uZsYM,50
torchzero/modules/higher_order/higher_order_newton.py,sha256=_v5v0WY07CvZn9QPIS89FxEZ2tNfd8Bkamt1o12_mLQ,12255
torchzero/modules/line_search/__init__.py,sha256=9ja1Dspfuzu9UxGbU5-t0bFeBcdwoX9Fl_aSMR-AXnQ,219
torchzero/modules/line_search/adaptive.py,sha256=Uj7lAIzpgy89ddlwA4VcEEIfcNJSbGA5HH3ncuzHrTU,2926
torchzero/modules/line_search/backtracking.py,sha256=dyXgfrIJ_IO7W4p8GqJNPc4r_igU4X4ljLCLNKyY2Tw,8246
torchzero/modules/line_search/line_search.py,sha256=_u59XYFkRsIKuT1H4Bz7qAHr3Ldzxbup71OeqDGxMfs,9724
torchzero/modules/line_search/polynomial.py,sha256=KlK0d9qaphxS0s8B5rlt-yIUYNuV-5O24STcx4vN2Ic,9056
torchzero/modules/line_search/scipy.py,sha256=eGplW1L8kQKdRbt9PPpvZ6MMekDq5KsjurhSpN9QCnY,2301
torchzero/modules/line_search/strong_wolfe.py,sha256=F5962HTHdPWgvWHwnUofCqFxfKsCu5p8Ic-aRbn7wVg,8458
torchzero/modules/misc/__init__.py,sha256=cZpMkZQubuzquhFZV-yELrDMznqhhCibmr0CBOR0ZpU,693
torchzero/modules/misc/debug.py,sha256=iuWg5egoMnG6y3Cyd423xS7BRVYiwZq9575d7A7U3Dg,1652
torchzero/modules/misc/escape.py,sha256=1XgNmT4pOptaXHSWEONkUPpcYnIujm5gdK6n_-zmw20,1821
torchzero/modules/misc/gradient_accumulation.py,sha256=6yXRUxD_f3Zfx83UyCvPJ-56XN4GJjEQcNIDlvFtuuY,2590
torchzero/modules/misc/misc.py,sha256=VTQZAcfQBo2yudy1u1lyHhmaAmQlxzVcZTHcXXnUeTM,13470
torchzero/modules/misc/multistep.py,sha256=rAPCALSHXjVNxR8d1CA3RFP_xnN6j5KksjB6yl8vtng,5585
torchzero/modules/misc/regularization.py,sha256=R8ya7HEF2MLtcAr7GS9IjXwJ4xh0lJWMdWMIRfwL42s,6279
torchzero/modules/misc/split.py,sha256=ebc95OZjC-Vs73JeTkL--eZrtKijg7lPN0hmD0Whfxc,3195
torchzero/modules/misc/switch.py,sha256=72mfY_uIVyTllwuR21_K7QC8IQFP7JMKzH4K2nAx0Wc,3726
torchzero/modules/momentum/__init__.py,sha256=tI2I5zSQB7aTwEn371wvUTy2O2n_-KVCafjBv-OMsYE,545
torchzero/modules/momentum/averaging.py,sha256=gZRjHb443HuFF03p3Oh2rfgh2Qu8sJBxc_8NR-ircaA,3241
torchzero/modules/momentum/cautious.py,sha256=QP3Sqc8nMb3xTDDDfGwFn5AWvN4EI5U-CCcZb-F5oX0,8266
torchzero/modules/momentum/ema.py,sha256=9OdMF20RYnEkwe9Xu2dCAAiI0qY2MQvhS87bKP7ptTI,10755
torchzero/modules/momentum/experimental.py,sha256=WnM9FUKPxyFNiKU6Ip7wqqYxHbXuaMKOcLjjomfENb4,6916
torchzero/modules/momentum/matrix_momentum.py,sha256=gZeTJZbhgixCOkE9Jyowtva58hl5vsH9iTqGC54FWFs,8047
torchzero/modules/momentum/momentum.py,sha256=Yx35jtbLb1syVFcTiNSoZPoUPmdsUy3QpoNWcN4sC9w,2664
torchzero/modules/ops/__init__.py,sha256=1q9CBo6OXWXDgyjvKKTlG0EdP4ASIvkWFXtd6LOuU88,1083
torchzero/modules/ops/accumulate.py,sha256=kyjiC9M9fugpG5Pc07XUi6GEWBvRi8iJ-7_Mb1SXQzE,3665
torchzero/modules/ops/binary.py,sha256=mIeaa3v5Bk7mwzSTC0jGMLhKf-Ujg6aFbSia2yo-3JQ,12199
torchzero/modules/ops/multi.py,sha256=DpabTYj0sic5dmosnmj7lgIX3dbmcgl0h9XfzKpbaus,8918
torchzero/modules/ops/reduce.py,sha256=uLCq493hFy_Ib22GjIKtMHTTObK3RDmubGHTVqgFgg8,6339
torchzero/modules/ops/unary.py,sha256=EFA_A834KmA6Ec3pZWH5XxZ9OzAhZZudwAwsP4GWZA0,5476
torchzero/modules/ops/utility.py,sha256=9Skxkt4RO79OBdw95wOKhqKN2RMdZg9emO7R9q2d5oU,3767
torchzero/modules/optimizers/__init__.py,sha256=IJaLoZ39rbB4GSW9rLKrfSCh5FsAkFy2ww5MhJ6MYnE,817
torchzero/modules/optimizers/adagrad.py,sha256=p-DWbhGuuogldiFPNxxQfJ8AA5Tsd4UwGOIyX7GT0WE,5892
torchzero/modules/optimizers/adahessian.py,sha256=vOJfwGi7ypfi7vifCMJfGew-McdGJKQM3TmkT-OUgI0,8682
torchzero/modules/optimizers/adam.py,sha256=SkJ7UJ1BOAgfregmzYDFo_3cgPNke_RK9B58hOal_Zg,3954
torchzero/modules/optimizers/adan.py,sha256=aOG6KGLU4oHYeQn3JB-A4NQ-279QpHA7firY3kkhFR4,3311
torchzero/modules/optimizers/adaptive_heavyball.py,sha256=DnkWHA0GBLIKCq8nWh76fZA6PnJ3eKsJDBXWKnZ_uIs,2127
torchzero/modules/optimizers/esgd.py,sha256=WXwYPA-qTA_QW9h4NDwNaly9gbi1uvMQ-5fSuLqnPkQ,6413
torchzero/modules/optimizers/ladagrad.py,sha256=HQb7LuZnG8SvS8JWqu7JJz_owlkyT-fnqeICrJBQxbc,7314
torchzero/modules/optimizers/lion.py,sha256=XFyglRNdnP1l8CmEZ7L_ZB8HWiR03BsZ_PEFCvHijb8,1127
torchzero/modules/optimizers/mars.py,sha256=7tr32x2eQNu8ZVQAPnLIkM2kkYp7S57uiDywTdqy1uY,2710
torchzero/modules/optimizers/msam.py,sha256=nvoo6smewR3hiCCymZQiB3DlCvLBGxfxlovJF2bwwsc,6588
torchzero/modules/optimizers/muon.py,sha256=AZKpmkVUjukXtI7Pb9PKDEeycreLF6qYlIMSbV_9IuA,10463
torchzero/modules/optimizers/orthograd.py,sha256=KbQuudjKgYVJcq1jRW_YmR2pPnwmAwyx9X_vrJAJgN4,2029
torchzero/modules/optimizers/rmsprop.py,sha256=ugZLfH4dXvHTxawtGWQL6xSfsjBDl_t1s29aFN9FMuY,4345
torchzero/modules/optimizers/rprop.py,sha256=nFpnqcXevGkUcPWERDX9gsiBCGgOi4pyPFloL68zwPY,11984
torchzero/modules/optimizers/sam.py,sha256=yEhXAS3v62nhAvs63RZ80VfZ93MaQ0cyMQziFdy6e2U,5711
torchzero/modules/optimizers/shampoo.py,sha256=m_XOvo2Eb1HP8QqYFPsT0rgczJ8HqKjh67QmtaY9dVg,9544
torchzero/modules/optimizers/soap.py,sha256=MXQ8fdBzLyFtgW34fnmY3hQqv3q4QwEthho9kK-72VE,11305
torchzero/modules/optimizers/sophia_h.py,sha256=dgQwjij5R4zdESYoKhc4BMhb6dKkDuEvjlL4bDdeQtw,7213
torchzero/modules/projections/__init__.py,sha256=4LfmBEu_eM4YWmcWQVH4CdI1H0ucCIHDH9tTGigjVPY,136
torchzero/modules/projections/cast.py,sha256=FJx2Tt1lbQRnOC5wxx3LbOnacLfUluFP6QOXLUCIEPY,2174
torchzero/modules/projections/galore.py,sha256=GDJ7hf6cdk_Iu2qW0rWaQwYLQAxQEe27FEfOiZvFXHo,252
torchzero/modules/projections/projection.py,sha256=PU2e9LNfVMnNrXnBDt-hdr5pVtl0TpgiB4b92WUguSs,14005
torchzero/modules/quasi_newton/__init__.py,sha256=guTCpbAffZyupnThdPxAsLULAmPP3vdPaNfPCe9KB9Y,854
torchzero/modules/quasi_newton/cg.py,sha256=HCfza5UInco7_hYT8s3duNRTmBdjbw5jscWLKNUiS8w,14453
torchzero/modules/quasi_newton/diagonal_quasi_newton.py,sha256=bMvIcWifYlJX83UtXFESMw7OdA4AO7tJwlHZwkc5wx0,6555
torchzero/modules/quasi_newton/lbfgs.py,sha256=BmE5sOFLFoJDlpoSphM5VowMgt7wtEFihbLkdylDXhM,10638
torchzero/modules/quasi_newton/lsr1.py,sha256=a19a9ABqMiTVJmXe6Woc0sJ1kkhQa3Y6QDouaUNnPt0,7873
torchzero/modules/quasi_newton/quasi_newton.py,sha256=hKJ9Irmh2pKNfB7Wen4MrDfMrbvzp00FTcPlpFvJLDU,48582
torchzero/modules/quasi_newton/trust_region.py,sha256=cxOEDeZ8ZhG_w7QXGYnTsF-t5g5zZ39q9Uxb2IXWgAY,15213
torchzero/modules/second_order/__init__.py,sha256=Trje1qM65yp8WWzuRm-tMTRqfKi4wpI7f8yyZWjhPCw,152
torchzero/modules/second_order/newton.py,sha256=94LGrQo5Q8aC5DI9S6RSXF0stVcgWzq3JnE9l_BsVUw,12875
torchzero/modules/second_order/newton_cg.py,sha256=l8FX9vQSVCSkpk5a-M2wEBBjQoODF-T07GFW_tjJxkM,14890
torchzero/modules/second_order/nystrom.py,sha256=yAJijWCl-K8k63YSJUqE_kXEIFmL_FjDghVjQoutAXo,11352
torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
torchzero/modules/smoothing/gaussian.py,sha256=iTsWlMNHuDLoxPRIsm2pAb5cS8OqdRJwCsw-vUTVmpE,7887
torchzero/modules/smoothing/laplacian.py,sha256=05Y6ft0GHRGkfSDV-g8vlaTlZTiXMr79xDagJb126ug,5302
torchzero/modules/step_size/__init__.py,sha256=Z8NpB9RYIXhcNx11NWixa7mORPiT4nI1mKQGA7JfC6g,122
torchzero/modules/step_size/adaptive.py,sha256=3qQr1aaPYEJlkiDSQbuVQ_OVkOq-W4LL7PkHFFgwP2c,4845
torchzero/modules/step_size/lr.py,sha256=I9-aIxei4Y2XnlOoCKvec2r__cTY_JTwBDlMf2O5D2A,5908
torchzero/modules/weight_decay/__init__.py,sha256=7UHAiiimsbQ_dHlxxcW87G5cCQFom9Uh_733W_23PWU,93
torchzero/modules/weight_decay/weight_decay.py,sha256=2MhWRyryplDtB61QyKN7KqBa3mEkhtqXhij8LGR-mYA,5464
torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
torchzero/modules/wrappers/optim_wrapper.py,sha256=B8ZwZf-qzChBfbx-cwL8Rez4AgH7FzvsT7N1S2SUiR8,4417
torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
torchzero/optim/utility/split.py,sha256=ZbazNuMTYunm75V_5ard0A_LletGaYAg-Pm2rANJKrE,1610
torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
torchzero/optim/wrappers/directsearch.py,sha256=GQ2nzy9ADqbV_QUMN3IaYecZ0Pzx_3mAasSB4fryTBE,11362
torchzero/optim/wrappers/fcmaes.py,sha256=o_FchMtDsrEj9XRonHHeyVHPAXTHaU244SzlldgEzLg,4250
torchzero/optim/wrappers/mads.py,sha256=Zi9u3vNlgNsCaIdYLl2_jgRA_dQrmVAuG0V31BFuct4,3087
torchzero/optim/wrappers/nevergrad.py,sha256=U_ZAHD_nEsJZ71cJ8TQ_DOZcTmS06EEvUPvaaDRSxWI,4901
torchzero/optim/wrappers/nlopt.py,sha256=AaVEKfjbrt5DFION44_-g-jQAoVi4lCvBBPU5UDGO9Q,8151
torchzero/optim/wrappers/optuna.py,sha256=ZZ66aXEypSJMVomphbzHNJnmIOyXS9tqE89YZBPpIuo,2331
torchzero/optim/wrappers/scipy.py,sha256=Td1AvpLDEPqPVW6IpHbkVW4CpNiUU9r_eyc3qJVHZAY,19352
torchzero/utils/__init__.py,sha256=4JMKzF3qICE9PSfgXAwb3cPswM5f1JUutWwviev2-0k,875
torchzero/utils/compile.py,sha256=N8AWLv_7oBUHYornmvvx_L4uynjiD-x5Hj1tBwei3-w,5127
torchzero/utils/derivatives.py,sha256=IIn4stpMMJxYmGKh1JCH4Gha_a4w8Z5G04uVz2BwMP4,16995
torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
torchzero/utils/optimizer.py,sha256=r52qu6pEcRH4lCXVlLxW5IweA6L-VrQj6RCMfdhzRpw,12466
torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
torchzero/utils/python_tools.py,sha256=NEyDVJfLBbdwh5m49qiOdIr0NffZRqKhaJ-cktviD1o,3243
torchzero/utils/tensorlist.py,sha256=WvjhPzGbgRySAsUBFQ7b-39V9rm7jbR1VOeYZQXiiKw,53925
torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
torchzero/utils/linalg/__init__.py,sha256=tsUt20_rbA_3pV6NK7yCkGoX1l0D9ayMKwZeySsYxHw,291
torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
torchzero/utils/linalg/solve.py,sha256=JF0i_eJTBRKCs7CONUOV7coPjE46NC5nMaz2JotrvSE,11232
torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
torchzero-0.3.11.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
torchzero-0.3.11.dist-info/METADATA,sha256=Czo-sKnlVxQ75MhY3D61oD8lusASV0ez_l697dyJBNc,15797
torchzero-0.3.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
torchzero-0.3.11.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
torchzero-0.3.11.dist-info/RECORD,,
@@ -1,21 +0,0 @@ torchzero-0.3.11.dist-info/licenses/LICENSE (removed)

MIT License

Copyright (c) 2024 inikishev

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File without changes: /torchzero/modules/{optimizers → adaptive}/lion.py
File without changes: /torchzero/modules/{optimizers → adaptive}/orthograd.py
File without changes: /torchzero/modules/{optimizers → adaptive}/rmsprop.py
File without changes: {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL