torchzero 0.3.8__tar.gz → 0.3.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.3.8 → torchzero-0.3.10}/PKG-INFO +14 -14
- {torchzero-0.3.8 → torchzero-0.3.10}/README.md +13 -13
- {torchzero-0.3.8 → torchzero-0.3.10}/pyproject.toml +1 -1
- {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_opts.py +55 -22
- {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_tensorlist.py +3 -3
- {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_vars.py +61 -61
- torchzero-0.3.10/torchzero/core/__init__.py +2 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/core/module.py +49 -49
- torchzero-0.3.10/torchzero/core/transform.py +313 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/__init__.py +1 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/clipping/clipping.py +10 -10
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/clipping/ema_clipping.py +14 -13
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/clipping/growth_clipping.py +16 -18
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/__init__.py +12 -3
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/absoap.py +50 -156
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/adadam.py +15 -14
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/adamY.py +17 -27
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/adasoap.py +20 -130
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/curveball.py +12 -12
- torchzero-0.3.10/torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero-0.3.10/torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero-0.3.10/torchzero/modules/experimental/etf.py +172 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/gradmin.py +2 -2
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero-0.3.10/torchzero/modules/experimental/newtonnewton.py +88 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/soapy.py +19 -146
- torchzero-0.3.10/torchzero/modules/experimental/spectral.py +163 -0
- torchzero-0.3.10/torchzero/modules/experimental/structured_newton.py +111 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero-0.3.10/torchzero/modules/experimental/tada.py +38 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/fdm.py +2 -2
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero-0.3.10/torchzero/modules/higher_order/__init__.py +1 -0
- torchzero-0.3.10/torchzero/modules/higher_order/higher_order_newton.py +256 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/backtracking.py +42 -23
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/line_search.py +40 -40
- torchzero-0.3.10/torchzero/modules/line_search/scipy.py +52 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/strong_wolfe.py +21 -32
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/trust_region.py +18 -6
- torchzero-0.3.10/torchzero/modules/lr/__init__.py +2 -0
- torchzero-0.3.8/torchzero/modules/lr/step_size.py → torchzero-0.3.10/torchzero/modules/lr/adaptive.py +22 -26
- torchzero-0.3.10/torchzero/modules/lr/lr.py +63 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/momentum/averaging.py +25 -10
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/momentum/cautious.py +73 -35
- torchzero-0.3.10/torchzero/modules/momentum/ema.py +224 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/momentum/experimental.py +21 -13
- torchzero-0.3.10/torchzero/modules/momentum/matrix_momentum.py +166 -0
- torchzero-0.3.10/torchzero/modules/momentum/momentum.py +63 -0
- torchzero-0.3.10/torchzero/modules/ops/accumulate.py +95 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/binary.py +36 -36
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/debug.py +7 -7
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/misc.py +128 -129
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/multi.py +19 -19
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/reduce.py +16 -16
- torchzero-0.3.10/torchzero/modules/ops/split.py +75 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/switch.py +4 -4
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/unary.py +20 -20
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/utility.py +37 -37
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/adagrad.py +33 -24
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/adam.py +31 -34
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/lion.py +4 -4
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/muon.py +6 -6
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/orthograd.py +4 -5
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/rmsprop.py +13 -16
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/rprop.py +52 -49
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/shampoo.py +17 -23
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/soap.py +12 -19
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/sophia_h.py +13 -13
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/dct.py +4 -4
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/fft.py +6 -6
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/galore.py +1 -1
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/projection.py +57 -57
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/structural.py +17 -17
- torchzero-0.3.10/torchzero/modules/quasi_newton/__init__.py +36 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/cg.py +76 -26
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/lbfgs.py +15 -15
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/lsr1.py +18 -17
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/olbfgs.py +19 -19
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/quasi_newton.py +257 -48
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/second_order/newton.py +38 -21
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/second_order/newton_cg.py +13 -12
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/second_order/nystrom.py +19 -19
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/smoothing/gaussian.py +21 -21
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/smoothing/laplacian.py +7 -9
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero-0.3.10/torchzero/modules/weight_decay/weight_decay.py +86 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero-0.3.10/torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero-0.3.10/torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero-0.3.10/torchzero/optim/wrappers/mads.py +90 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/wrappers/nevergrad.py +4 -4
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero-0.3.10/torchzero/optim/wrappers/optuna.py +70 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/wrappers/scipy.py +162 -13
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/__init__.py +2 -6
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/derivatives.py +2 -1
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/optimizer.py +55 -74
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/PKG-INFO +14 -14
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/SOURCES.txt +13 -4
- torchzero-0.3.8/torchzero/core/__init__.py +0 -3
- torchzero-0.3.8/torchzero/core/preconditioner.py +0 -138
- torchzero-0.3.8/torchzero/core/transform.py +0 -252
- torchzero-0.3.8/torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero-0.3.8/torchzero/modules/experimental/spectral.py +0 -288
- torchzero-0.3.8/torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.8/torchzero/modules/line_search/scipy.py +0 -37
- torchzero-0.3.8/torchzero/modules/lr/__init__.py +0 -2
- torchzero-0.3.8/torchzero/modules/lr/lr.py +0 -59
- torchzero-0.3.8/torchzero/modules/momentum/ema.py +0 -173
- torchzero-0.3.8/torchzero/modules/momentum/matrix_momentum.py +0 -124
- torchzero-0.3.8/torchzero/modules/momentum/momentum.py +0 -43
- torchzero-0.3.8/torchzero/modules/ops/accumulate.py +0 -65
- torchzero-0.3.8/torchzero/modules/ops/split.py +0 -75
- torchzero-0.3.8/torchzero/modules/quasi_newton/__init__.py +0 -7
- torchzero-0.3.8/torchzero/modules/weight_decay/weight_decay.py +0 -52
- {torchzero-0.3.8 → torchzero-0.3.10}/LICENSE +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/docs/source/conf.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/setup.cfg +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_identical.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_module.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/functional.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/second_order/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/compile.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/benchmark.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/matrix_funcs.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/orthogonalize.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/qr.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/solve.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/svd.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/ops.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/params.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/tensorlist.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.3.8 → torchzero-0.3.10}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.8
+Version: 0.3.10
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 License: MIT License
@@ -157,13 +157,14 @@ for epoch in range(100):
   * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
   * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
   * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
+  * `HigherOrderNewton`: Higher order Newton's method with trust region.
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
   * `LBFGS`: Limited-memory BFGS.
   * `LSR1`: Limited-memory SR1.
   * `OnlineLBFGS`: Online LBFGS.
-  * `BFGS`, `SR1`, `
-  * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+  * `BFGS`, `DFP`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `ColumnUpdatingMethod`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`: Classic full-matrix quasi-newton methods.
+  * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
 
 * **Line Search**:
   * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
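All of the names added above live under `tz.m` and chain inside `tz.Modular` like the existing modules; the updated tests later in this diff pair them with a `StrongWolfe` line search. A minimal construction sketch, assuming only those names (the parameter tensor is a placeholder, and the pairings mirror the test suite rather than any recommended defaults):

```python
import torch
import torchzero as tz

params = [torch.nn.Parameter(torch.randn(10))]

# classic full-matrix quasi-newton methods are typically followed by a line search
bfgs = tz.Modular(params, tz.m.BFGS(), tz.m.StrongWolfe())
dfp = tz.Modular(params, tz.m.DFP(), tz.m.StrongWolfe())

# conjugate gradient variants differ only in how the beta coefficient is computed
cg = tz.Modular(params, tz.m.ProjectedGradientMethod(), tz.m.StrongWolfe())

# the higher-order Newton module added in this release
hon = tz.Modular(params, tz.m.HigherOrderNewton())
```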
@@ -312,20 +313,20 @@ not in the module itself. Also both per-parameter settings and state are stored
 
 ```python
 import torch
-from torchzero.core import Module,
+from torchzero.core import Module, Var
 
 class HeavyBall(Module):
     def __init__(self, momentum: float = 0.9, dampening: float = 0):
         defaults = dict(momentum=momentum, dampening=dampening)
         super().__init__(defaults)
 
-    def step(self,
-        # a module takes a
-        #
+    def step(self, var: Var):
+        # a module takes a Var object, modifies it or creates a new one, and returns it
+        # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
         # for now we are only interested in update, and we will apply the heavyball rule to it.
 
-        params =
-        update =
+        params = var.params
+        update = var.get_update() # list of tensors
 
         exp_avg_list = []
         for p, u in zip(params, update):
@@ -346,16 +347,15 @@ class HeavyBall(Module):
             # and it is part of self.state
             exp_avg_list.append(buf.clone())
 
-        # set new update to
-
-        return
+        # set new update to var
+        var.update = exp_avg_list
+        return var
 ```
 
 There are a some specialized base modules that make it much easier to implement some specific things.
 
 * `GradApproximator` for gradient approximations
 * `LineSearch` for line searches
-* `Preconditioner` for preconditioners
 * `Projection` for projections like GaLore or into fourier domain.
 * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
 * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
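For orientation, this is roughly how the full README example reads after the rename to `Var`. Only the fragments visible in the hunks above are taken from the diff; the middle of the loop (reading per-parameter settings and updating the buffer) is not shown here, so that portion, including the `self.settings` and `self.state` accessors, is a hedged guess at a typical heavy-ball update rather than the README's exact code:

```python
import torch
from torchzero.core import Module, Var

class HeavyBall(Module):
    def __init__(self, momentum: float = 0.9, dampening: float = 0):
        defaults = dict(momentum=momentum, dampening=dampening)
        super().__init__(defaults)

    def step(self, var: Var):
        # a module takes a Var object, modifies it or creates a new one, and returns it
        params = var.params
        update = var.get_update()  # list of tensors

        exp_avg_list = []
        for p, u in zip(params, update):
            # per-parameter settings and state are stored outside the module itself;
            # the `settings` accessor used here is an assumption, not shown in the diff
            momentum = self.settings[p]['momentum']
            dampening = self.settings[p]['dampening']

            # the buffer persists between steps, and it is part of self.state
            buf = self.state[p].setdefault('exp_avg', torch.zeros_like(u))
            buf.mul_(momentum).add_(u, alpha=1 - dampening)  # assumed heavy-ball rule
            exp_avg_list.append(buf.clone())

        # set new update to var
        var.update = exp_avg_list
        return var
```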
@@ -376,4 +376,4 @@ There are also wrappers providing `torch.optim.Optimizer` interface for for `sci
 
 They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.
 
-Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
+Apparently <https://github.com/avaneev/biteopt> is diabolical so I will add a wrapper for it too very soon.
{torchzero-0.3.8 → torchzero-0.3.10}/README.md

@@ -118,13 +118,14 @@ for epoch in range(100):
   * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
   * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
   * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
+  * `HigherOrderNewton`: Higher order Newton's method with trust region.
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
   * `LBFGS`: Limited-memory BFGS.
   * `LSR1`: Limited-memory SR1.
   * `OnlineLBFGS`: Online LBFGS.
-  * `BFGS`, `SR1`, `
-  * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+  * `BFGS`, `DFP`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `ColumnUpdatingMethod`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`: Classic full-matrix quasi-newton methods.
+  * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.
 
 * **Line Search**:
   * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
@@ -273,20 +274,20 @@ not in the module itself. Also both per-parameter settings and state are stored
 
 ```python
 import torch
-from torchzero.core import Module,
+from torchzero.core import Module, Var
 
 class HeavyBall(Module):
     def __init__(self, momentum: float = 0.9, dampening: float = 0):
         defaults = dict(momentum=momentum, dampening=dampening)
         super().__init__(defaults)
 
-    def step(self,
-        # a module takes a
-        #
+    def step(self, var: Var):
+        # a module takes a Var object, modifies it or creates a new one, and returns it
+        # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
         # for now we are only interested in update, and we will apply the heavyball rule to it.
 
-        params =
-        update =
+        params = var.params
+        update = var.get_update() # list of tensors
 
         exp_avg_list = []
         for p, u in zip(params, update):
@@ -307,16 +308,15 @@ class HeavyBall(Module):
             # and it is part of self.state
             exp_avg_list.append(buf.clone())
 
-        # set new update to
-
-        return
+        # set new update to var
+        var.update = exp_avg_list
+        return var
 ```
 
 There are a some specialized base modules that make it much easier to implement some specific things.
 
 * `GradApproximator` for gradient approximations
 * `LineSearch` for line searches
-* `Preconditioner` for preconditioners
 * `Projection` for projections like GaLore or into fourier domain.
 * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
 * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
@@ -337,4 +337,4 @@ There are also wrappers providing `torch.optim.Optimizer` interface for for `sci
 
 They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.
 
-Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
+Apparently <https://github.com/avaneev/biteopt> is diabolical so I will add a wrapper for it too very soon.
{torchzero-0.3.8 → torchzero-0.3.10}/tests/test_opts.py

@@ -1,4 +1,9 @@
-"""
+"""
+Sanity tests to make sure everything works.
+
+This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
+don't error or become unhinged with different parameter shapes.
+"""
 from collections.abc import Callable
 from functools import partial
 
@@ -68,6 +73,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
         assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
         losses.append(loss)
 
+    losses.append(objective())
     return torch.stack(losses).nan_to_num(0,10000,10000).min()
 
 def _run_func(opt_fn: Callable, func:str, merge: bool, use_closure: bool, steps: int):
@@ -524,7 +530,7 @@ PolyakStepSize = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
     sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
     needs_closure=True,
-    func='booth', steps=50, loss=1e-
+    func='booth', steps=50, loss=1e-7, merge_invariant=True,
     sphere_steps=10, sphere_loss=0.002,
 )
 RandomStepSize = Run(
@@ -604,44 +610,44 @@ ScaleModulesByCosineSimilarity = Run(
 
 # ------------------------- momentum/matrix_momentum ------------------------- #
 MatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 MatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 MatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
     needs_closure=True,
     func='booth', steps=50, loss=0.05, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 
 AdaptiveMatrixMomentum_forward = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
     needs_closure=True,
     func='booth', steps=50, loss=0.002, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 AdaptiveMatrixMomentum_central = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
     needs_closure=True,
     func='booth', steps=50, loss=0.002, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 AdaptiveMatrixMomentum_autograd = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(
+    func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
     needs_closure=True,
     func='booth', steps=50, loss=0.002, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
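The `hvp_method` argument being filled in above selects how the Hessian-vector product is obtained ('forward' and 'central' presumably via finite differences of the gradient, 'autograd' via a second backward pass). A hedged sketch of the same composition outside the `Run` harness, using the booth function and step budget from the test configuration; everything else is illustrative:

```python
import torch
import torchzero as tz

x = torch.nn.Parameter(torch.zeros(2))
opt = tz.Modular([x], tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01))

def closure(backward=True):
    # booth function, the same objective the test runs against
    loss = (x[0] + 2 * x[1] - 7) ** 2 + (2 * x[0] + x[1] - 5) ** 2
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(50):  # the test expects loss <= 0.05 after 50 steps
    opt.step(closure)
```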
@@ -719,11 +725,11 @@ Lion = Run(
 )
 # ---------------------------- optimizers/shampoo ---------------------------- #
 Shampoo = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(
-    sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.
+    func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
     needs_closure=False,
-    func='booth', steps=50, loss=
-    sphere_steps=20, sphere_loss=
+    func='booth', steps=50, loss=0.02, merge_invariant=False,
+    sphere_steps=20, sphere_loss=1, # merge and unmerge lrs are very different so need to test convergence separately somewhere
 )
 
 # ------------------------- quasi_newton/quasi_newton ------------------------ #
@@ -745,7 +751,7 @@ SSVM = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
     sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
     needs_closure=True,
-    func='rosen', steps=50, loss=
+    func='rosen', steps=50, loss=1e-10, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 
@@ -791,7 +797,7 @@ NewtonCG = Run(
     sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
     needs_closure=True,
     func='rosen', steps=20, loss=1e-7, merge_invariant=True,
-    sphere_steps=2, sphere_loss=
+    sphere_steps=2, sphere_loss=3e-4,
 )
 
 # ---------------------------- smoothing/gaussian ---------------------------- #
@@ -854,8 +860,17 @@ SophiaH = Run(
     sphere_steps=10, sphere_loss=40,
 )
 
+# -------------------------- optimizers/higher_order ------------------------- #
+HigherOrderNewton = Run(
+    func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
+    sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
+    needs_closure=True,
+    func='rosen', steps=1, loss=2e-10, merge_invariant=True,
+    sphere_steps=1, sphere_loss=1e-10,
+)
+
 # ------------------------------------ CGs ----------------------------------- #
-for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY):
+for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
     for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
         # but also test 10 to make sure it doesn't explode after converging
         Run(
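The new `HigherOrderNewton` run above is expected to essentially solve the 2D Rosenbrock function in a single step when the trust region is disabled. A hedged sketch of that configuration outside the test harness (the starting point is illustrative, and the objective is the standard a=1, b=100 Rosenbrock form, assumed to match the test's 'rosen'):

```python
import torch
import torchzero as tz

xy = torch.nn.Parameter(torch.tensor([-1.1, 2.5]))
opt = tz.Modular([xy], tz.m.HigherOrderNewton(trust_method=None))

def closure(backward=True):
    x, y = xy[0], xy[1]
    loss = (1 - x) ** 2 + 100 * (y - x ** 2) ** 2  # rosenbrock
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # the test asserts a loss around 2e-10 after one step
```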
@@ -868,7 +883,25 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da
 
 # ------------------------------- QN stability ------------------------------- #
 # stability test
-for QN in (
+for QN in (
+    tz.m.BFGS,
+    tz.m.SR1,
+    tz.m.DFP,
+    tz.m.BroydenGood,
+    tz.m.BroydenBad,
+    tz.m.Greenstadt1,
+    tz.m.Greenstadt2,
+    tz.m.ColumnUpdatingMethod,
+    tz.m.ThomasOptimalMethod,
+    tz.m.FletcherVMM,
+    tz.m.Horisho,
+    lambda scale_first: tz.m.Horisho(scale_first=scale_first, inner=tz.m.GradientCorrection()),
+    tz.m.Pearson,
+    tz.m.ProjectedNewtonRaphson,
+    tz.m.PSB,
+    tz.m.McCormick,
+    tz.m.SSVM,
+):
     Run(
         func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
         sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
{torchzero-0.3.8 → torchzero-0.3.10}/tests/test_tensorlist.py

@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho
     expected = vec_equiv_func()
 
     if isinstance(result, bool): assert result == expected
-    else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
+    else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"
 
 
 def test_global_vector_norm(simple_tl: TensorList):
@@ -1261,8 +1261,8 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
         elif reduction_method == 'quantile': expected = vec.quantile(q)
         else:
             pytest.fail("Unknown global reduction")
-            assert False,
-        assert torch.allclose(result, expected)
+            assert False, reduction_method
+        assert torch.allclose(result, expected, atol=1e-4)
     else:
         expected_list = []
         for t in simple_tl:
{torchzero-0.3.8 → torchzero-0.3.10}/tests/test_vars.py

@@ -1,10 +1,10 @@
 import pytest
 import torch
-from torchzero.core.module import
+from torchzero.core.module import Var
 from torchzero.utils.tensorlist import TensorList
 
 @torch.no_grad
-def
+def test_var_get_loss():
 
     # ---------------------------- test that it works ---------------------------- #
     params = [torch.tensor(2.0, requires_grad=True)]
@@ -26,20 +26,20 @@ def test_vars_get_loss():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss
 
-
+    var = Var(params=params, closure=closure_1, model=None, current_step=0)
 
-    assert
+    assert var.loss is None, var.loss
 
-    assert (loss :=
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
     assert evaluated, evaluated
-    assert loss is
-    assert
-    assert
-    assert
+    assert loss is var.loss
+    assert var.loss == 4.0
+    assert var.loss_approx == 4.0
+    assert var.grad is None, var.grad
 
     # reevaluate, which should just return already evaluated loss
-    assert (loss :=
-    assert
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert var.grad is None, var.grad
 
 
     # ----------------------- test that backward=True works ---------------------- #
@@ -61,30 +61,30 @@ def test_vars_get_loss():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss
 
-
-    assert
-    assert (loss :=
-    assert
-    assert
+    var = Var(params=params, closure=closure_2, model=None, current_step=0)
+    assert var.grad is None, var.grad
+    assert (loss := var.get_loss(backward=True)) == 6.0, loss
+    assert var.grad is not None
+    assert var.grad[0] == 2.0, var.grad
 
     # reevaluate, which should just return already evaluated loss
-    assert (loss :=
-    assert
+    assert (loss := var.get_loss(backward=True)) == 6.0, loss
+    assert var.grad[0] == 2.0, var.grad
 
     # get grad, which should just return already evaluated grad
-    assert (grad :=
-    assert grad is
+    assert (grad := var.get_grad())[0] == 2.0, grad
+    assert grad is var.grad, grad
 
     # get update, which should create and return cloned grad
-    assert
-    assert (update :=
-    assert update is
-    assert update is not
-    assert
-    assert update[0] ==
+    assert var.update is None
+    assert (update := var.get_update())[0] == 2.0, update
+    assert update is var.update
+    assert update is not var.grad
+    assert var.grad is not None
+    assert update[0] == var.grad[0]
 
 @torch.no_grad
-def
+def test_var_get_grad():
     params = [torch.tensor(2.0, requires_grad=True)]
     evaluated = False
 
@@ -103,20 +103,20 @@ def test_vars_get_grad():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss
 
-
-    assert (grad :=
-    assert grad is
+    var = Var(params=params, closure=closure, model=None, current_step=0)
+    assert (grad := var.get_grad())[0] == 4.0, grad
+    assert grad is var.grad
 
-    assert
-    assert (loss :=
-    assert (loss :=
-    assert
+    assert var.loss == 4.0
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert (loss := var.get_loss(backward=True)) == 4.0, loss
+    assert var.loss_approx == 4.0
 
-    assert
-    assert (update :=
+    assert var.update is None, var.update
+    assert (update := var.get_update())[0] == 4.0, update
 
 @torch.no_grad
-def
+def test_var_get_update():
     params = [torch.tensor(2.0, requires_grad=True)]
     evaluated = False
 
@@ -135,24 +135,24 @@ def test_vars_get_update():
         assert not loss.requires_grad, "loss requires grad with backward=False"
         return loss
 
-
-    assert
-    assert (update :=
-    assert update is
+    var = Var(params=params, closure=closure, model=None, current_step=0)
+    assert var.update is None, var.update
+    assert (update := var.get_update())[0] == 4.0, update
+    assert update is var.update
 
-    assert (grad :=
-    assert grad is
+    assert (grad := var.get_grad())[0] == 4.0, grad
+    assert grad is var.grad
     assert grad is not update
 
-    assert
-    assert (loss :=
-    assert (loss :=
-    assert
+    assert var.loss == 4.0
+    assert (loss := var.get_loss(backward=False)) == 4.0, loss
+    assert (loss := var.get_loss(backward=True)) == 4.0, loss
+    assert var.loss_approx == 4.0
 
-    assert (update :=
+    assert (update := var.get_update())[0] == 4.0, update
 
 
-def
+def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
     for k,v in v1.__dict__.items():
         if not k.startswith('__'):
             # if k == 'post_step_hooks': continue
@@ -165,20 +165,20 @@ def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
             else:
                 assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'
 
-def
+def test_var_clone():
     model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
     def closure(backward): return 1
-
+    var = Var(params=list(model.parameters()), closure=closure, model=model, current_step=0)
 
-
-
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
 
-
-
-
+    var.grad = TensorList(torch.randn(5))
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
 
-
-
-
-
-
+    var.update = TensorList(torch.randn(5) * 2)
+    var.loss = torch.randn(1)
+    var.loss_approx = var.loss
+    _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+    _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)