torchzero 0.3.5__tar.gz → 0.3.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.3.5 → torchzero-0.3.8}/PKG-INFO +2 -2
- {torchzero-0.3.5 → torchzero-0.3.8}/README.md +1 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/pyproject.toml +2 -2
- {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_opts.py +1 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_tensorlist.py +17 -17
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/preconditioner.py +11 -10
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/__init__.py +3 -2
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/absoap.py +8 -2
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/adadam.py +1 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/adamY.py +1 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/adasoap.py +1 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/algebraic_newton.py +1 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/curveball.py +1 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/gradmin.py +1 -1
- torchzero-0.3.8/torchzero/modules/experimental/newton_solver.py +88 -0
- torchzero-0.3.5/torchzero/modules/experimental/dsoap.py → torchzero-0.3.8/torchzero/modules/experimental/soapy.py +4 -4
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/spectral.py +5 -3
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/subspace_preconditioners.py +16 -9
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/soap.py +1 -2
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/projection.py +27 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +4 -3
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/lbfgs.py +4 -3
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/lsr1.py +6 -3
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/quasi_newton.py +16 -17
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/second_order/__init__.py +1 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/second_order/newton_cg.py +1 -1
- torchzero-0.3.8/torchzero/utils/linalg/benchmark.py +20 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/solve.py +15 -14
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/PKG-INFO +2 -2
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/SOURCES.txt +3 -1
- {torchzero-0.3.5 → torchzero-0.3.8}/LICENSE +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/docs/source/conf.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/setup.cfg +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_identical.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_module.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/tests/test_vars.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/module.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/transform.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/clipping/clipping.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/clipping/ema_clipping.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/clipping/growth_clipping.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/tropical_newton.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/functional.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/fdm.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/grad_approximation/rfdm.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/backtracking.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/line_search.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/scipy.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/strong_wolfe.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/line_search/trust_region.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/lr/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/lr/lr.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/lr/step_size.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/averaging.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/cautious.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/ema.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/experimental.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/matrix_momentum.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/momentum/momentum.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/accumulate.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/binary.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/debug.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/misc.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/multi.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/reduce.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/split.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/switch.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/unary.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/ops/utility.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/adagrad.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/adam.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/lion.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/muon.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/orthograd.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/rmsprop.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/rprop.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/shampoo.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/sophia_h.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/dct.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/fft.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/galore.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/structural.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/cg.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/olbfgs.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/second_order/newton.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/second_order/nystrom.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/smoothing/gaussian.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/smoothing/laplacian.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/weight_decay/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/weight_decay/weight_decay.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/wrappers/nevergrad.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/optim/wrappers/scipy.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/compile.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/derivatives.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/__init__.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/matrix_funcs.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/orthogonalize.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/qr.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/linalg/svd.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/ops.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/optimizer.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/params.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/python_tools.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/tensorlist.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.3.5 → torchzero-0.3.8}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.3.5 → torchzero-0.3.8}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.5
+Version: 0.3.8
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 License: MIT License
@@ -156,7 +156,7 @@ for epoch in range(100):
 * `Newton`: Classic Newton's method.
 * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
 * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-* `NystromPCG`: NewtonCG with Nyström preconditioning (
+* `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
 * `LBFGS`: Limited-memory BFGS.
{torchzero-0.3.5 → torchzero-0.3.8}/README.md

@@ -117,7 +117,7 @@ for epoch in range(100):
 * `Newton`: Classic Newton's method.
 * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
 * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-* `NystromPCG`: NewtonCG with Nyström preconditioning (
+* `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
 * `LBFGS`: Limited-memory BFGS.
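The modules listed above are composed through tz.Modular, the same way the test file further down builds its optimizers (tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe())). A minimal usage sketch under that API; the closure signature with a backward flag and the zero_grad call are assumptions for illustration, not taken from this diff:

import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
# module names from the README list above, composed in the style of tests/test_opts.py
opt = tz.Modular(model.parameters(), tz.m.NewtonCG(), tz.m.StrongWolfe())

def closure(backward=True):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)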
{torchzero-0.3.5 → torchzero-0.3.8}/pyproject.toml

@@ -2,7 +2,7 @@
 # STEP 1 - COMMIT NEW CHANGES BUT DON'T PUSH THEM YET
 # STEP 2 - BUMP VERSION AND COMMIT IT (DONT PUSH!!!!)
 # STEP 3 - CREATE TAG WITH THAT VERSION
-# STEP 4 - PUSH CHANGES
+# STEP 4 - PUSH (SYNC) CHANGES
 # STEP 5 - PUSH TAG
 
 [build-system]

@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
 name = "torchzero"
 description = "Modular optimization library for PyTorch."
 
-version = "0.3.5"
+version = "0.3.8"
 dependencies = [
     "torch",
     "numpy",
{torchzero-0.3.5 → torchzero-0.3.8}/tests/test_opts.py

@@ -745,7 +745,7 @@ SSVM = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
     sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
     needs_closure=True,
-    func='rosen', steps=50, loss=
+    func='rosen', steps=50, loss=0.02, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 
{torchzero-0.3.5 → torchzero-0.3.8}/tests/test_tensorlist.py

@@ -1301,7 +1301,7 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
     expected_tl = TensorList(expected_list)
     assert isinstance(result, TensorList)
     assert len(result) == len(expected_tl)
-    assert_tl_allclose(result, expected_tl, atol=1e-
+    assert_tl_allclose(result, expected_tl, atol=1e-3) # Use allclose due to potential float variations
 
 # --- Grafting, Rescaling, Normalizing, Clipping ---
 

@@ -1381,8 +1381,8 @@ def test_rescale(simple_tl: TensorList, dim):
         assert torch.allclose(rescaled_scalar.global_min(), torch.tensor(min_val))
         assert torch.allclose(rescaled_scalar.global_max(), torch.tensor(max_val))
     else:
-        assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-
-        assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-
+        assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-3)
+        assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-3)
 
 
     # Rescale list

@@ -1402,8 +1402,8 @@ def test_rescale(simple_tl: TensorList, dim):
         assert global_max_rescaled < avg_max + 1.0 # Loose check
 
     else:
-        assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-
-        assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-
+        assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-3)
+        assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-3)
 
     # Rescale to 01 helper
     rescaled_01 = simple_tl.rescale_to_01(dim=dim, eps=eps)

@@ -1413,8 +1413,8 @@ def test_rescale(simple_tl: TensorList, dim):
         assert torch.allclose(rescaled_01.global_min(), torch.tensor(0.0))
         assert torch.allclose(rescaled_01.global_max(), torch.tensor(1.0))
     else:
-        assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-
-        assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-
+        assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-3)
+        assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-3)
 
 
     # Test inplace

@@ -1454,11 +1454,11 @@ def test_normalize(big_tl: TensorList, dim):
     normalized_scalar_var = normalized_scalar.var(dim=dim if dim != 'global' else None)
 
     if dim == 'global':
-        assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-
-        assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-
+        assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-3)
+        assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-3)
     else:
-        assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-
-        assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-
+        assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-3)
+        assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-3)
 
     # Normalize list mean/var
     normalized_list = simple_tl.normalize(mean_list, var_list, dim=dim)

@@ -1476,19 +1476,19 @@ def test_normalize(big_tl: TensorList, dim):
         # assert torch.allclose(global_mean_rescaled, torch.tensor(avg_mean), rtol=1e-1, atol=1e-1) # Loose check
         # assert torch.allclose(global_var_rescaled, torch.tensor(avg_var), rtol=1e-1, atol=1e-1) # Loose check
     else:
-        assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-
-        assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-
+        assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-3)
+        assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-3)
 
     # Z-normalize helper
     znorm = simple_tl.znormalize(dim=dim, eps=1e-10)
     znorm_mean = znorm.mean(dim=dim if dim != 'global' else None)
     znorm_var = znorm.var(dim=dim if dim != 'global' else None)
     if dim == 'global':
-        assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-
-        assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-
+        assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-3)
+        assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-3)
     else:
-        assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-
-        assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-
+        assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-3)
+        assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-3)
 
 
     # Test inplace
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/core/preconditioner.py

@@ -38,7 +38,7 @@ class Preconditioner(Transform):
 
 
     def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('
+        step = self.global_state.get('__step', 0)
         states = [self.state[p] for p in params]
         settings = [self.settings[p] for p in params]
         global_settings = settings[0]

@@ -47,8 +47,10 @@ class Preconditioner(Transform):
         scale_first = global_settings['__scale_first']
         scale_factor = 0
         if scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS
-
+            # initial step size guess from pytorch LBFGS was too unstable
+            # I switched to norm
+            tensors = TensorList(tensors)
+            scale_factor = tensors.abs().global_mean().clip(min=1)
 
         # update preconditioner
         if step % update_freq == 0:

@@ -65,11 +67,11 @@ class Preconditioner(Transform):
         if scale_first and step == 0:
             torch._foreach_div_(tensors, scale_factor)
 
-        self.global_state['
+        self.global_state['__step'] = step + 1
         return tensors
 
     def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('
+        step = self.global_state.get('__step', 0)
         tensors_vec = torch.cat([t.ravel() for t in tensors])
         params_vec = torch.cat([p.ravel() for p in params])
         grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None

@@ -82,8 +84,8 @@ class Preconditioner(Transform):
         scale_first = global_settings['__scale_first']
         scale_factor = 0
         if scale_first and step == 0:
-            # initial step size guess from pytorch LBFGS
-            scale_factor = tensors_vec.abs().
+            # initial step size guess from pytorch LBFGS was too unstable
+            scale_factor = tensors_vec.abs().mean().clip(min=1)
 
         # update preconditioner
         if step % update_freq == 0:

@@ -99,11 +101,10 @@ class Preconditioner(Transform):
 
         # scale initial step, when preconditioner might not have been applied
         if scale_first and step == 0:
-
-            tensors_vec /= scale_factor
+            tensors_vec /= scale_factor
 
         tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
-        self.global_state['
+        self.global_state['__step'] = step + 1
         return tensors
 
     @torch.no_grad
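The new first-step heuristic above replaces the PyTorch-LBFGS-style guess with a simpler rule: divide the very first update by its mean absolute value, clipped to at least 1, so the step stays conservative before the preconditioner has seen any statistics. A standalone illustration in plain PyTorch (hypothetical helper name, not torchzero code):

import torch

def scale_first_step(update: torch.Tensor) -> torch.Tensor:
    # mean absolute value, clipped to at least 1, so already-small updates pass through unchanged
    scale_factor = update.abs().mean().clip(min=1)
    return update / scale_factor

print(scale_first_step(torch.tensor([10.0, -20.0, 5.0])))  # scaled down to roughly unit size
print(scale_first_step(torch.tensor([0.1, -0.2, 0.05])))   # unchanged, scale_factor clips to 1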
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/__init__.py

@@ -3,7 +3,7 @@ from .adadam import Adadam
 from .adamY import AdamY
 from .adasoap import AdaSOAP
 from .curveball import CurveBall
-from .
+from .soapy import SOAPY
 from .gradmin import GradMin
 from .reduce_outward_lr import ReduceOutwardLR
 from .spectral import SpectralPreconditioner

@@ -11,4 +11,5 @@ from .subspace_preconditioners import (
     HistorySubspacePreconditioning,
     RandomSubspacePreconditioning,
 )
-from .tropical_newton import TropicalNewton
+from .tropical_newton import TropicalNewton
+from .newton_solver import NewtonSolver
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/absoap.py

@@ -140,11 +140,17 @@ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
 class ABSOAP(Transform):
     """SOAP but with two extra letters included in its name in order to improve converence
 
+    so what you can do is choose what goes into what ,and that is supposed to be good.
+
     new args
 
     scale by s whether to scale gradient differences by parameter differences
 
     y_to_ema2 whether to use gradient differences for exponential moving average too
+
+    okay I changed these args into another ones
+
+    BASICALLY THIS IS FOR MY EXPERIMENTS
     """
     def __init__(
         self,

@@ -213,7 +219,7 @@ class ABSOAP(Transform):
             if 'g_prev' not in state:
                 state['p_prev'] = p.clone()
                 state['g_prev'] = t.clone()
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue
 
             p_prev = state['p_prev']

@@ -285,7 +291,7 @@ class ABSOAP(Transform):
                 state['Q'] = get_orthogonal_matrix(state['GG'])
 
                 state['step'] = 0
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
 
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/adasoap.py

@@ -218,7 +218,7 @@ class AdaSOAP(Transform):
                 state['Q'] = get_orthogonal_matrix(GG_precond)
 
                 state['step'] = 0
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
 
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/algebraic_newton.py

@@ -71,7 +71,7 @@ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemir
 
 
 class AlgebraicNewton(Module):
-    """newton in other algebras, not
+    """newton in other algebras, not that it works."""
     def __init__(
         self,
         reg: float | None = None,
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/curveball.py

@@ -13,7 +13,7 @@ def curveball(
     momentum: float | NumberList,
     precond_lr: float | NumberList,
 ):
-    """returns z_, clone it!!!"""
+    """returns z_, clone it!!! (no just negate it)"""
     delta = Hz + tensors
     z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
     return z_
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/gradmin.py

@@ -14,7 +14,7 @@ from ..smoothing.gaussian import Reformulation
 
 
 class GradMin(Reformulation):
-    """Reformulates the objective to minimize sum of gradient magnitudes via autograd.
+    """Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.
 
     Args:
         loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
torchzero-0.3.8/torchzero/modules/experimental/newton_solver.py

@@ -0,0 +1,88 @@
+from collections.abc import Callable, Iterable
+from typing import Any, Literal, overload
+
+import torch
+
+from ...core import Chainable, Module, apply, Modular
+from ...utils import TensorList, as_tensorlist
+from ...utils.derivatives import hvp
+from ..quasi_newton import LBFGS
+
+class NewtonSolver(Module):
+    """Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
+    def __init__(
+        self,
+        solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
+        maxiter=None,
+        tol=1e-3,
+        reg: float = 0,
+        warm_start=True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        solver_cls = settings['solver']
+        maxiter = settings['maxiter']
+        tol = settings['tol']
+        reg = settings['reg']
+        warm_start = settings['warm_start']
+
+        # ---------------------- Hessian vector product function --------------------- #
+        grad = vars.get_grad(create_graph=True)
+
+        def H_mm(x):
+            with torch.enable_grad():
+                Hvp = TensorList(hvp(params, grad, x, create_graph=True))
+                if reg != 0: Hvp = Hvp + (x*reg)
+            return Hvp
+
+        # -------------------------------- inner step -------------------------------- #
+        b = as_tensorlist(grad)
+        if 'inner' in self.children:
+            b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+
+        # ---------------------------------- run cg ---------------------------------- #
+        x0 = None
+        if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+        if x0 is None: x = b.zeros_like().requires_grad_(True)
+        else: x = x0.clone().requires_grad_(True)
+
+        solver = solver_cls(x)
+        def lstsq_closure(backward=True):
+            Hx = H_mm(x)
+            loss = (Hx-b).pow(2).global_mean()
+            if backward:
+                solver.zero_grad()
+                loss.backward(inputs=x)
+            return loss
+
+        if maxiter is None: maxiter = b.global_numel()
+        loss = None
+        initial_loss = lstsq_closure(False)
+        if initial_loss > tol:
+            for i in range(maxiter):
+                loss = solver.step(lstsq_closure)
+                assert loss is not None
+                if min(loss, loss/initial_loss) < tol: break
+
+        print(f'{loss = }')
+
+        if warm_start:
+            assert x0 is not None
+            x0.copy_(x)
+
+        vars.update = x.detach()
+        return vars
+
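NewtonSolver above never forms the Hessian: it treats the Newton system H x = g as the least-squares problem min_x ||H x − g||² and hands it to a pluggable optimizer, with H x computed matrix-free through double backward. A standalone sketch of the same idea in plain PyTorch, with torch.optim.LBFGS standing in for the pluggable solver (an illustration under these assumptions, not torchzero's API):

import torch

torch.manual_seed(0)
w = torch.randn(5, requires_grad=True)                    # parameters
loss = (w ** 4).sum() + (w ** 2).sum()                    # toy objective with a positive definite Hessian
g = torch.autograd.grad(loss, w, create_graph=True)[0]    # gradient, graph kept for double backward

def hvp(v: torch.Tensor) -> torch.Tensor:
    # Hessian-vector product H @ v via a second backward pass;
    # create_graph=True makes the result differentiable with respect to v
    return torch.autograd.grad(g, w, grad_outputs=v, retain_graph=True, create_graph=True)[0]

x = torch.zeros_like(w, requires_grad=True)               # Newton direction being solved for
solver = torch.optim.LBFGS([x], max_iter=50)              # stand-in for the pluggable solver

def lstsq_closure():
    solver.zero_grad()
    residual = hvp(x) - g.detach()
    objective = residual.pow(2).mean()
    objective.backward(inputs=[x], retain_graph=True)
    return objective

solver.step(lstsq_closure)
print(x.detach())   # approximately H^{-1} g, i.e. the Newton step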
torchzero-0.3.5/torchzero/modules/experimental/dsoap.py → torchzero-0.3.8/torchzero/modules/experimental/soapy.py
RENAMED

@@ -3,7 +3,7 @@ from operator import itemgetter
 import torch
 
 from ...core import Chainable, Transform, apply
-from
+from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
 
 @torch.no_grad
 def update_soap_covariances_(

@@ -135,7 +135,7 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
 
     return final, exp_avg_sq
 
-class
+class SOAPY(Transform):
     """SOAP but uses scaled gradient differences
 
     new args

@@ -195,7 +195,7 @@ class DSOAP(Transform):
             if 'g_prev' not in state:
                 state['p_prev'] = p.clone()
                 state['g_prev'] = t.clone()
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue
 
             p_prev = state['p_prev']

@@ -228,7 +228,7 @@ class DSOAP(Transform):
                 state['Q'] = get_orthogonal_matrix(state['GG'])
 
                 state['step'] = 0
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
 
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/spectral.py

@@ -194,8 +194,10 @@ class SpectralPreconditioner(TensorwisePreconditioner):
         order (int, optional):
             whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
         solver (str, optional): what to use for whitening. Defaults to 'svd'.
-
-
+        A_beta (float | None, optional):
+            beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
+        B_beta (float | None, optional):
+            beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
         interval (int, optional): How often to update history. Defaults to 1 (every step).
         concat_params (bool, optional):
             whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.

@@ -275,7 +277,7 @@ class SpectralPreconditioner(TensorwisePreconditioner):
         A = state.get('A', None)
         if A is None:
             # make a conservative step to avoid issues due to different GD scaling
-            return tensor.
+            return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
 
         B = state['B']
         update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/experimental/subspace_preconditioners.py
RENAMED

@@ -38,16 +38,19 @@ def apply_subspace_preconditioner(
     return basis @ update_projected # d
 
 class RandomSubspacePreconditioning(Transform):
-    """full matrix rmsprop in random subspace"""
-    def __init__(self, k: int, beta: float | None = 0.99):
-        defaults = dict(k=k, beta=beta)
+    """full matrix rmsprop in random slowly changing subspace"""
+    def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
+        defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
         super().__init__(defaults, uses_grad=False)
 
+        if inner is not None: self.set_child('inner', inner)
+
     def transform(self, tensors, params, grads, vars):
         settings = self.settings[params[0]]
         g = torch.cat([t.view(-1) for t in tensors])
         k = settings['k']
         beta = settings['beta']
+        basis_beta = settings['basis_beta']
 
         if 'basis' not in self.global_state:
             self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)

@@ -56,13 +59,19 @@ class RandomSubspacePreconditioning(Transform):
         basis = self.global_state['basis']
         accumulator = self.global_state['accumulator']
 
+        if basis_beta is not None:
+            basis.lerp_(torch.randn_like(basis), 1-basis_beta)
+
         update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+        if 'inner' in self.children:
+            tensors = apply(self.children['inner'], tensors, params, grads, vars)
+            g = torch.cat([t.view(-1) for t in tensors])
+
         try:
             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
         except torch.linalg.LinAlgError:
-
-            if denom <= 1e-10: denom = torch.ones_like(denom)
-            preconditioned = g / g.abs().sum()
+            preconditioned = g.clip(-0.1, 0.1)
         vec_to_tensors_(preconditioned, tensors)
 
         return tensors

@@ -119,9 +128,7 @@ class HistorySubspacePreconditioning(Transform):
         try:
             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
         except torch.linalg.LinAlgError:
-
-            if denom <= 1e-10: denom = torch.ones_like(denom)
-            preconditioned = g / g.abs().sum()
+            preconditioned = g.clip(-0.1,0.1)
         vec_to_tensors_(preconditioned, tensors)
 
         return tensors
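The new basis_beta option above keeps the random subspace from staying frozen at initialization: every step the basis is lerped toward fresh Gaussian noise by a factor of 1 − basis_beta, so the subspace drifts slowly. A minimal standalone sketch of that drift (hypothetical sizes, not torchzero code):

import torch

n, k, basis_beta = 100, 8, 0.99
basis = torch.randn(n, k)
for step in range(5):
    # keep 99% of the old basis, mix in 1% fresh noise per step
    basis.lerp_(torch.randn_like(basis), 1 - basis_beta)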
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/optimizers/soap.py

@@ -222,8 +222,7 @@ class SOAP(Transform):
                 state['Q'] = get_orthogonal_matrix(state['GG'])
 
                 state['step'] = 0
-                updates.append(tensors[i].
-                # updates.append(tensors[i] / tensors[i].abs().sum())
+                updates.append(tensors[i].clip(-0.1, 0.1))
                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                 # I use scaled update instead as to not mess up with next modules.
 
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/projections/projection.py

@@ -1,4 +1,5 @@
 import math
+from functools import partial
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Any, Literal

@@ -33,6 +34,25 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",
 
     return projected_closure
 
+def _projected_get_grad_override(
+    retain_graph: bool | None = None,
+    create_graph: bool = False,
+    projection: Any = ...,
+    unprojected_vars: Any = ...,
+    self: Any = ...,
+):
+    assert isinstance(projection, Projection)
+    assert isinstance(unprojected_vars, Vars)
+    assert isinstance(self, Vars)
+
+    if self.grad is not None: return self.grad
+    grads = unprojected_vars.get_grad(retain_graph, create_graph)
+    projected_grads = list(projection.project(grads, self, current='grads'))
+    self.grad = projected_grads
+    for p, g in zip(self.params, projected_grads):
+        p.grad = g
+    return self.grad
+
 
 class Projection(Module, ABC):
     """

@@ -137,6 +157,12 @@ class Projection(Module, ABC):
 
         # step
         projected_vars.params = self._projected_params
+        projected_vars.get_grad = partial(
+            _projected_get_grad_override,
+            projection=self,
+            unprojected_vars=vars,
+            self=projected_vars,
+        )
         projected_vars = self.children['modules'].step(projected_vars)
 
         # empty fake params storage

@@ -149,7 +175,7 @@ class Projection(Module, ABC):
         unprojected_vars = projected_vars.clone(clone_update=False)
         unprojected_vars.closure = vars.closure
         unprojected_vars.params = vars.params
-
+        unprojected_vars.grad = vars.grad
 
         if self._project_update:
             assert projected_vars.update is not None
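The _projected_get_grad_override added above is bound with functools.partial, so child modules that call projected_vars.get_grad() lazily pull gradients from the unprojected vars, project them once, and cache the result. A minimal standalone sketch of that override pattern (toy classes, not torchzero's actual Vars/Projection API):

from functools import partial

class Vars:
    def __init__(self, grad=None):
        self.grad = grad
    def get_grad(self):
        return self.grad

def projected_get_grad(projection=None, unprojected_vars=None, self=None):
    # compute and cache projected gradients only on first request
    if self.grad is not None:
        return self.grad
    self.grad = [projection(g) for g in unprojected_vars.get_grad()]
    return self.grad

unprojected = Vars(grad=[1.0, -2.0, 3.0])
projected = Vars()
projected.get_grad = partial(projected_get_grad, projection=abs,
                             unprojected_vars=unprojected, self=projected)
print(projected.get_grad())   # [1.0, 2.0, 3.0]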
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py
RENAMED

@@ -37,10 +37,11 @@ def lbfgs(
     z_tfm: Any,
 ):
     if len(s_history) == 0 or y_k is None or ys_k is None:
-        # dir = params.grad.sign() # may work fine
 
-        # initial step size guess
-
+        # initial step size guess modified from pytorch L-BFGS
+        scale = 1 / tensors_.abs().global_sum()
+        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
 
     else:
         # 1st loop
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/lbfgs.py

@@ -36,10 +36,11 @@ def lbfgs(
     step: int,
 ):
     if len(s_history) == 0 or y_k is None or ys_k is None:
-        # dir = params.grad.sign() # may work fine
 
-        # initial step size guess
-
+        # initial step size guess modified from pytorch L-BFGS
+        scale = 1 / tensors_.abs().global_sum()
+        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
 
     else:
         # 1st loop
{torchzero-0.3.5 → torchzero-0.3.8}/torchzero/modules/quasi_newton/lsr1.py

@@ -17,8 +17,9 @@ def lsr1_(
 ):
     if step == 0 or not s_history:
         # initial step size guess from pytorch
-
-
+        scale = 1 / tensors_.abs().global_sum()
+        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
 
     m = len(s_history)
 

@@ -64,7 +65,9 @@ def lsr1_(
     Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
 
     if scale_second and step == 1:
-
+        scale = 1 / tensors_.abs().global_sum()
+        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
+        Hx.mul_(min(1.0, scale)) # pyright:ignore[reportArgumentType]
     return Hx
 
 
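The L-BFGS and L-SR1 hunks above all share one first-step heuristic: scale the raw gradient by min(1, 1/‖g‖₁), and fall back to 1/mean(|g|) when the sum is so large that the scale drops below 1e-5. A standalone sketch of that rule (hypothetical function name, not torchzero code):

import torch

def initial_step(g: torch.Tensor) -> torch.Tensor:
    scale = 1.0 / g.abs().sum().item()
    if scale < 1e-5:
        scale = 1.0 / g.abs().mean().item()
    return g * min(1.0, scale)

print(initial_step(torch.tensor([3.0, -4.0])))        # scaled by 1/7
print(initial_step(torch.full((200_000,), 1.0)))      # sum is huge, falls back to 1/mean(|g|) = 1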