torchzero 0.3.8__tar.gz → 0.3.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.3.8 → torchzero-0.3.9}/PKG-INFO +1 -1
- {torchzero-0.3.8 → torchzero-0.3.9}/pyproject.toml +1 -1
- {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_opts.py +1 -1
- {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_tensorlist.py +1 -1
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/preconditioner.py +10 -10
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/adasoap.py +1 -1
- torchzero-0.3.9/torchzero/modules/experimental/structured_newton.py +111 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/cg.py +9 -9
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lbfgs.py +3 -3
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lsr1.py +7 -6
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/quasi_newton.py +3 -1
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/newton.py +11 -6
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/newton_cg.py +2 -2
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/nystrom.py +6 -6
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/PKG-INFO +1 -1
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/SOURCES.txt +1 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/LICENSE +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/README.md +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/docs/source/conf.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/setup.cfg +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_identical.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_module.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/tests/test_vars.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/module.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/core/transform.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/clipping/clipping.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/clipping/ema_clipping.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/clipping/growth_clipping.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/absoap.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/adadam.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/adamY.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/algebraic_newton.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/curveball.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/gradmin.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/newton_solver.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/soapy.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/spectral.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/subspace_preconditioners.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/tropical_newton.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/functional.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/fdm.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/rfdm.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/backtracking.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/line_search.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/scipy.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/strong_wolfe.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/line_search/trust_region.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/lr/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/lr/lr.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/lr/step_size.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/averaging.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/cautious.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/ema.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/experimental.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/matrix_momentum.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/momentum/momentum.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/accumulate.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/binary.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/debug.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/misc.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/multi.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/reduce.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/split.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/switch.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/unary.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/ops/utility.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/adagrad.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/adam.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/lion.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/muon.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/orthograd.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/rmsprop.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/rprop.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/shampoo.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/soap.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/optimizers/sophia_h.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/dct.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/fft.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/galore.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/projection.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/projections/structural.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/olbfgs.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/second_order/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/smoothing/gaussian.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/smoothing/laplacian.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/weight_decay/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/weight_decay/weight_decay.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/wrappers/nevergrad.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/optim/wrappers/scipy.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/compile.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/derivatives.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/__init__.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/benchmark.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/matrix_funcs.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/orthogonalize.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/qr.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/solve.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/linalg/svd.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/ops.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/optimizer.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/params.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/python_tools.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/tensorlist.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.3.8 → torchzero-0.3.9}/torchzero.egg-info/top_level.txt +0 -0
|
@@ -745,7 +745,7 @@ SSVM = Run(
|
|
|
745
745
|
func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
|
|
746
746
|
sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
|
|
747
747
|
needs_closure=True,
|
|
748
|
-
func='rosen', steps=50, loss=
|
|
748
|
+
func='rosen', steps=50, loss=1e-10, merge_invariant=True,
|
|
749
749
|
sphere_steps=10, sphere_loss=0,
|
|
750
750
|
)
|
|
751
751
|
|
|
@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho
|
|
|
835
835
|
expected = vec_equiv_func()
|
|
836
836
|
|
|
837
837
|
if isinstance(result, bool): assert result == expected
|
|
838
|
-
else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
|
|
838
|
+
else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"
|
|
839
839
|
|
|
840
840
|
|
|
841
841
|
def test_global_vector_norm(simple_tl: TensorList):
|
|
@@ -45,12 +45,11 @@ class Preconditioner(Transform):
|
|
|
45
45
|
update_freq = global_settings['__update_freq']
|
|
46
46
|
|
|
47
47
|
scale_first = global_settings['__scale_first']
|
|
48
|
-
scale_factor =
|
|
48
|
+
scale_factor = 1
|
|
49
49
|
if scale_first and step == 0:
|
|
50
|
-
# initial step size guess from pytorch LBFGS
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
scale_factor = tensors.abs().global_mean().clip(min=1)
|
|
50
|
+
# initial step size guess from pytorch LBFGS
|
|
51
|
+
scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
|
|
52
|
+
scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
|
|
54
53
|
|
|
55
54
|
# update preconditioner
|
|
56
55
|
if step % update_freq == 0:
|
|
@@ -65,7 +64,7 @@ class Preconditioner(Transform):
|
|
|
65
64
|
|
|
66
65
|
# scale initial step, when preconditioner might not have been applied
|
|
67
66
|
if scale_first and step == 0:
|
|
68
|
-
torch.
|
|
67
|
+
torch._foreach_mul_(tensors, scale_factor)
|
|
69
68
|
|
|
70
69
|
self.global_state['__step'] = step + 1
|
|
71
70
|
return tensors
|
|
@@ -82,10 +81,11 @@ class Preconditioner(Transform):
|
|
|
82
81
|
update_freq = global_settings['__update_freq']
|
|
83
82
|
|
|
84
83
|
scale_first = global_settings['__scale_first']
|
|
85
|
-
scale_factor =
|
|
84
|
+
scale_factor = 1
|
|
86
85
|
if scale_first and step == 0:
|
|
87
|
-
# initial step size guess from pytorch LBFGS
|
|
88
|
-
scale_factor = tensors_vec.abs().
|
|
86
|
+
# initial step size guess from pytorch LBFGS
|
|
87
|
+
scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
|
|
88
|
+
scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)
|
|
89
89
|
|
|
90
90
|
# update preconditioner
|
|
91
91
|
if step % update_freq == 0:
|
|
@@ -101,7 +101,7 @@ class Preconditioner(Transform):
|
|
|
101
101
|
|
|
102
102
|
# scale initial step, when preconditioner might not have been applied
|
|
103
103
|
if scale_first and step == 0:
|
|
104
|
-
tensors_vec
|
|
104
|
+
tensors_vec *= scale_factor
|
|
105
105
|
|
|
106
106
|
tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
|
|
107
107
|
self.global_state['__step'] = step + 1
|
|
@@ -220,7 +220,7 @@ class AdaSOAP(Transform):
|
|
|
220
220
|
state['step'] = 0
|
|
221
221
|
updates.append(tensors[i].clip(-0.1,0.1))
|
|
222
222
|
continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
|
|
223
|
-
#
|
|
223
|
+
# that can mess with other modules scaling
|
|
224
224
|
|
|
225
225
|
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
|
226
226
|
# i.e. projecting to the eigenbases of matrices in state['GG']
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# idea https://arxiv.org/pdf/2212.09841
|
|
2
|
+
import warnings
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from functools import partial
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from ...core import Chainable, Module, apply
|
|
10
|
+
from ...utils import TensorList, vec_to_tensors
|
|
11
|
+
from ...utils.derivatives import (
|
|
12
|
+
hessian_list_to_mat,
|
|
13
|
+
hessian_mat,
|
|
14
|
+
hvp,
|
|
15
|
+
hvp_fd_central,
|
|
16
|
+
hvp_fd_forward,
|
|
17
|
+
jacobian_and_hessian_wrt,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class StructuredNewton(Module):
|
|
22
|
+
"""TODO
|
|
23
|
+
Args:
|
|
24
|
+
structure (str, optional): structure.
|
|
25
|
+
reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
|
|
26
|
+
hvp_method (str):
|
|
27
|
+
how to calculate hvp_method. Defaults to "autograd".
|
|
28
|
+
inner (Chainable | None, optional): inner modules. Defaults to None.
|
|
29
|
+
|
|
30
|
+
"""
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
structure: Literal[
|
|
34
|
+
"diagonal",
|
|
35
|
+
"diagonal1",
|
|
36
|
+
"diagonal_abs",
|
|
37
|
+
"tridiagonal",
|
|
38
|
+
"circulant",
|
|
39
|
+
"toeplitz",
|
|
40
|
+
"toeplitz_like",
|
|
41
|
+
"hankel",
|
|
42
|
+
"rank1",
|
|
43
|
+
"rank2", # any rank
|
|
44
|
+
]
|
|
45
|
+
| str = "diagonal",
|
|
46
|
+
reg: float = 1e-6,
|
|
47
|
+
hvp_method: Literal["autograd", "forward", "central"] = "autograd",
|
|
48
|
+
h: float = 1e-3,
|
|
49
|
+
inner: Chainable | None = None,
|
|
50
|
+
):
|
|
51
|
+
defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
|
|
52
|
+
super().__init__(defaults)
|
|
53
|
+
|
|
54
|
+
if inner is not None:
|
|
55
|
+
self.set_child('inner', inner)
|
|
56
|
+
|
|
57
|
+
@torch.no_grad
|
|
58
|
+
def step(self, vars):
|
|
59
|
+
params = TensorList(vars.params)
|
|
60
|
+
closure = vars.closure
|
|
61
|
+
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
62
|
+
|
|
63
|
+
settings = self.settings[params[0]]
|
|
64
|
+
reg = settings['reg']
|
|
65
|
+
hvp_method = settings['hvp_method']
|
|
66
|
+
structure = settings['structure']
|
|
67
|
+
h = settings['h']
|
|
68
|
+
|
|
69
|
+
# ------------------------ calculate grad and hessian ------------------------ #
|
|
70
|
+
if hvp_method == 'autograd':
|
|
71
|
+
grad = vars.get_grad(create_graph=True)
|
|
72
|
+
def Hvp_fn1(x):
|
|
73
|
+
return hvp(params, grad, x, retain_graph=True)
|
|
74
|
+
Hvp_fn = Hvp_fn1
|
|
75
|
+
|
|
76
|
+
elif hvp_method == 'forward':
|
|
77
|
+
grad = vars.get_grad()
|
|
78
|
+
def Hvp_fn2(x):
|
|
79
|
+
return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
|
|
80
|
+
Hvp_fn = Hvp_fn2
|
|
81
|
+
|
|
82
|
+
elif hvp_method == 'central':
|
|
83
|
+
grad = vars.get_grad()
|
|
84
|
+
def Hvp_fn3(x):
|
|
85
|
+
return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
|
|
86
|
+
Hvp_fn = Hvp_fn3
|
|
87
|
+
|
|
88
|
+
else: raise ValueError(hvp_method)
|
|
89
|
+
|
|
90
|
+
# -------------------------------- inner step -------------------------------- #
|
|
91
|
+
update = vars.get_update()
|
|
92
|
+
if 'inner' in self.children:
|
|
93
|
+
update = apply(self.children['inner'], update, params=params, grads=grad, vars=vars)
|
|
94
|
+
|
|
95
|
+
# hessian
|
|
96
|
+
if structure.startswith('diagonal'):
|
|
97
|
+
H = Hvp_fn([torch.ones_like(p) for p in params])
|
|
98
|
+
if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
|
|
99
|
+
if structure == 'diagonal_abs': torch._foreach_abs_(H)
|
|
100
|
+
torch._foreach_add_(H, reg)
|
|
101
|
+
torch._foreach_div_(update, H)
|
|
102
|
+
vars.update = update
|
|
103
|
+
return vars
|
|
104
|
+
|
|
105
|
+
# hessian
|
|
106
|
+
raise NotImplementedError(structure)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
|
|
@@ -64,7 +64,7 @@ class ConguateGradientBase(Transform, ABC):
|
|
|
64
64
|
# ------------------------------- Polak-Ribière ------------------------------ #
|
|
65
65
|
def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
|
|
66
66
|
denom = prev_g.dot(prev_g)
|
|
67
|
-
if denom
|
|
67
|
+
if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
|
|
68
68
|
return g.dot(g - prev_g) / denom
|
|
69
69
|
|
|
70
70
|
class PolakRibiere(ConguateGradientBase):
|
|
@@ -76,8 +76,8 @@ class PolakRibiere(ConguateGradientBase):
|
|
|
76
76
|
return polak_ribiere_beta(g, prev_g)
|
|
77
77
|
|
|
78
78
|
# ------------------------------ Fletcher–Reeves ----------------------------- #
|
|
79
|
-
def fletcher_reeves_beta(gg, prev_gg):
|
|
80
|
-
if prev_gg
|
|
79
|
+
def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
|
|
80
|
+
if prev_gg.abs() <= torch.finfo(gg.dtype).eps: return 0
|
|
81
81
|
return gg / prev_gg
|
|
82
82
|
|
|
83
83
|
class FletcherReeves(ConguateGradientBase):
|
|
@@ -98,7 +98,7 @@ class FletcherReeves(ConguateGradientBase):
|
|
|
98
98
|
def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
99
99
|
grad_diff = g - prev_g
|
|
100
100
|
denom = prev_d.dot(grad_diff)
|
|
101
|
-
if denom
|
|
101
|
+
if denom.abs() < torch.finfo(g[0].dtype).eps: return 0
|
|
102
102
|
return (g.dot(grad_diff) / denom).neg()
|
|
103
103
|
|
|
104
104
|
|
|
@@ -114,7 +114,7 @@ class HestenesStiefel(ConguateGradientBase):
|
|
|
114
114
|
# --------------------------------- Dai–Yuan --------------------------------- #
|
|
115
115
|
def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
116
116
|
denom = prev_d.dot(g - prev_g)
|
|
117
|
-
if denom
|
|
117
|
+
if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
|
|
118
118
|
return (g.dot(g) / denom).neg()
|
|
119
119
|
|
|
120
120
|
class DaiYuan(ConguateGradientBase):
|
|
@@ -129,7 +129,7 @@ class DaiYuan(ConguateGradientBase):
|
|
|
129
129
|
# -------------------------------- Liu-Storey -------------------------------- #
|
|
130
130
|
def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
|
|
131
131
|
denom = prev_g.dot(prev_d)
|
|
132
|
-
if denom
|
|
132
|
+
if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
|
|
133
133
|
return g.dot(g - prev_g) / denom
|
|
134
134
|
|
|
135
135
|
class LiuStorey(ConguateGradientBase):
|
|
@@ -159,7 +159,7 @@ class ConjugateDescent(Transform):
|
|
|
159
159
|
self.global_state['denom'] = torch.tensor(0.).to(g[0])
|
|
160
160
|
|
|
161
161
|
prev_gd = self.global_state.get('prev_gd', 0)
|
|
162
|
-
if prev_gd
|
|
162
|
+
if abs(prev_gd) <= torch.finfo(g[0].dtype).eps: beta = 0
|
|
163
163
|
else: beta = g.dot(g) / prev_gd
|
|
164
164
|
|
|
165
165
|
# inner step
|
|
@@ -176,7 +176,7 @@ class ConjugateDescent(Transform):
|
|
|
176
176
|
def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
|
|
177
177
|
g_diff = g - prev_g
|
|
178
178
|
denom = prev_d.dot(g_diff)
|
|
179
|
-
if denom
|
|
179
|
+
if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
|
|
180
180
|
|
|
181
181
|
term1 = 1/denom
|
|
182
182
|
# term2
|
|
@@ -198,7 +198,7 @@ class HagerZhang(ConguateGradientBase):
|
|
|
198
198
|
def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
|
|
199
199
|
grad_diff = g - prev_g
|
|
200
200
|
denom = prev_d.dot(grad_diff)
|
|
201
|
-
if denom
|
|
201
|
+
if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
|
|
202
202
|
|
|
203
203
|
# Dai-Yuan
|
|
204
204
|
dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
|
|
@@ -38,9 +38,9 @@ def lbfgs(
|
|
|
38
38
|
if len(s_history) == 0 or y_k is None or ys_k is None:
|
|
39
39
|
|
|
40
40
|
# initial step size guess modified from pytorch L-BFGS
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
return tensors_.mul_(
|
|
41
|
+
scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
|
|
42
|
+
scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
|
|
43
|
+
return tensors_.mul_(scale_factor)
|
|
44
44
|
|
|
45
45
|
else:
|
|
46
46
|
# 1st loop
|
|
@@ -17,9 +17,9 @@ def lsr1_(
|
|
|
17
17
|
):
|
|
18
18
|
if step == 0 or not s_history:
|
|
19
19
|
# initial step size guess from pytorch
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
return tensors_.mul_(
|
|
20
|
+
scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
|
|
21
|
+
scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
|
|
22
|
+
return tensors_.mul_(scale_factor)
|
|
23
23
|
|
|
24
24
|
m = len(s_history)
|
|
25
25
|
|
|
@@ -65,9 +65,10 @@ def lsr1_(
|
|
|
65
65
|
Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
|
|
66
66
|
|
|
67
67
|
if scale_second and step == 1:
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
Hx.mul_(
|
|
68
|
+
scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
|
|
69
|
+
scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
|
|
70
|
+
Hx.mul_(scale_factor)
|
|
71
|
+
|
|
71
72
|
return Hx
|
|
72
73
|
|
|
73
74
|
|
|
@@ -122,7 +122,9 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
|
|
|
122
122
|
step = state.get('step', 0)
|
|
123
123
|
|
|
124
124
|
if settings['scale_second'] and step == 2:
|
|
125
|
-
|
|
125
|
+
scale_factor = 1 / tensor.abs().sum().clip(min=1)
|
|
126
|
+
scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
|
|
127
|
+
tensor = tensor * scale_factor
|
|
126
128
|
|
|
127
129
|
inverse = settings['inverse']
|
|
128
130
|
if inverse:
|
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
import warnings
|
|
2
|
+
from collections.abc import Callable
|
|
2
3
|
from functools import partial
|
|
3
4
|
from typing import Literal
|
|
4
|
-
|
|
5
|
+
|
|
5
6
|
import torch
|
|
6
7
|
|
|
7
|
-
from ...core import Chainable,
|
|
8
|
-
from ...utils import
|
|
8
|
+
from ...core import Chainable, Module, apply
|
|
9
|
+
from ...utils import TensorList, vec_to_tensors
|
|
9
10
|
from ...utils.derivatives import (
|
|
10
11
|
hessian_list_to_mat,
|
|
11
12
|
hessian_mat,
|
|
13
|
+
hvp,
|
|
14
|
+
hvp_fd_central,
|
|
15
|
+
hvp_fd_forward,
|
|
12
16
|
jacobian_and_hessian_wrt,
|
|
13
17
|
)
|
|
14
18
|
|
|
@@ -117,9 +121,10 @@ class Newton(Module):
|
|
|
117
121
|
raise ValueError(hessian_method)
|
|
118
122
|
|
|
119
123
|
# -------------------------------- inner step -------------------------------- #
|
|
124
|
+
update = vars.get_update()
|
|
120
125
|
if 'inner' in self.children:
|
|
121
|
-
|
|
122
|
-
g = torch.cat([t.view(-1) for t in
|
|
126
|
+
update = apply(self.children['inner'], update, params=params, grads=list(g_list), vars=vars)
|
|
127
|
+
g = torch.cat([t.view(-1) for t in update])
|
|
123
128
|
|
|
124
129
|
# ------------------------------- regulazition ------------------------------- #
|
|
125
130
|
if eig_reg: H = eig_tikhonov_(H, reg)
|
|
@@ -139,4 +144,4 @@ class Newton(Module):
|
|
|
139
144
|
if update is None: update = least_squares_solve(H, g)
|
|
140
145
|
|
|
141
146
|
vars.update = vec_to_tensors(update, params)
|
|
142
|
-
return vars
|
|
147
|
+
return vars
|
|
@@ -66,9 +66,9 @@ class NewtonCG(Module):
|
|
|
66
66
|
|
|
67
67
|
|
|
68
68
|
# -------------------------------- inner step -------------------------------- #
|
|
69
|
-
b =
|
|
69
|
+
b = vars.get_update()
|
|
70
70
|
if 'inner' in self.children:
|
|
71
|
-
b = as_tensorlist(apply(self.children['inner'],
|
|
71
|
+
b = as_tensorlist(apply(self.children['inner'], b, params=params, grads=grad, vars=vars))
|
|
72
72
|
|
|
73
73
|
# ---------------------------------- run cg ---------------------------------- #
|
|
74
74
|
x0 = None
|
|
@@ -15,7 +15,7 @@ class NystromSketchAndSolve(Module):
|
|
|
15
15
|
rank: int,
|
|
16
16
|
reg: float = 1e-3,
|
|
17
17
|
hvp_method: Literal["forward", "central", "autograd"] = "autograd",
|
|
18
|
-
h=1e-
|
|
18
|
+
h=1e-2,
|
|
19
19
|
inner: Chainable | None = None,
|
|
20
20
|
seed: int | None = None,
|
|
21
21
|
):
|
|
@@ -74,9 +74,9 @@ class NystromSketchAndSolve(Module):
|
|
|
74
74
|
|
|
75
75
|
|
|
76
76
|
# -------------------------------- inner step -------------------------------- #
|
|
77
|
-
b =
|
|
77
|
+
b = vars.get_update()
|
|
78
78
|
if 'inner' in self.children:
|
|
79
|
-
b = apply(self.children['inner'],
|
|
79
|
+
b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)
|
|
80
80
|
|
|
81
81
|
# ------------------------------ sketch&n&solve ------------------------------ #
|
|
82
82
|
x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
|
|
@@ -93,7 +93,7 @@ class NystromPCG(Module):
|
|
|
93
93
|
tol=1e-3,
|
|
94
94
|
reg: float = 1e-6,
|
|
95
95
|
hvp_method: Literal["forward", "central", "autograd"] = "autograd",
|
|
96
|
-
h=1e-
|
|
96
|
+
h=1e-2,
|
|
97
97
|
inner: Chainable | None = None,
|
|
98
98
|
seed: int | None = None,
|
|
99
99
|
):
|
|
@@ -156,9 +156,9 @@ class NystromPCG(Module):
|
|
|
156
156
|
|
|
157
157
|
|
|
158
158
|
# -------------------------------- inner step -------------------------------- #
|
|
159
|
-
b =
|
|
159
|
+
b = vars.get_update()
|
|
160
160
|
if 'inner' in self.children:
|
|
161
|
-
b = apply(self.children['inner'],
|
|
161
|
+
b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)
|
|
162
162
|
|
|
163
163
|
# ------------------------------ sketch&n&solve ------------------------------ #
|
|
164
164
|
x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
|
|
@@ -36,6 +36,7 @@ torchzero/modules/experimental/newton_solver.py
|
|
|
36
36
|
torchzero/modules/experimental/reduce_outward_lr.py
|
|
37
37
|
torchzero/modules/experimental/soapy.py
|
|
38
38
|
torchzero/modules/experimental/spectral.py
|
|
39
|
+
torchzero/modules/experimental/structured_newton.py
|
|
39
40
|
torchzero/modules/experimental/subspace_preconditioners.py
|
|
40
41
|
torchzero/modules/experimental/tropical_newton.py
|
|
41
42
|
torchzero/modules/grad_approximation/__init__.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/experimental/subspace_preconditioners.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/forward_gradient.py
RENAMED
|
File without changes
|
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/grad_approximation/grad_approximator.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{torchzero-0.3.8 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|