torchzero 0.3.6__tar.gz → 0.3.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchzero-0.3.6 → torchzero-0.3.9}/PKG-INFO +2 -2
- {torchzero-0.3.6 → torchzero-0.3.9}/README.md +1 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/pyproject.toml +2 -2
- {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_opts.py +1 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_tensorlist.py +1 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/preconditioner.py +12 -11
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/__init__.py +3 -2
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/absoap.py +8 -2
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/adadam.py +1 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/adamY.py +1 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/adasoap.py +2 -2
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/algebraic_newton.py +1 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/curveball.py +1 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/gradmin.py +1 -1
- torchzero-0.3.9/torchzero/modules/experimental/newton_solver.py +88 -0
- torchzero-0.3.6/torchzero/modules/experimental/dsoap.py → torchzero-0.3.9/torchzero/modules/experimental/soapy.py +4 -4
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/spectral.py +5 -3
- torchzero-0.3.9/torchzero/modules/experimental/structured_newton.py +111 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/subspace_preconditioners.py +16 -9
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/soap.py +1 -2
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/projection.py +27 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/cg.py +9 -9
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +4 -3
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lbfgs.py +4 -3
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/lsr1.py +7 -3
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/quasi_newton.py +18 -17
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/second_order/__init__.py +1 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/second_order/newton.py +11 -6
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/second_order/newton_cg.py +3 -3
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/second_order/nystrom.py +6 -6
- torchzero-0.3.9/torchzero/utils/linalg/benchmark.py +20 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/solve.py +15 -14
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/PKG-INFO +2 -2
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/SOURCES.txt +4 -1
- {torchzero-0.3.6 → torchzero-0.3.9}/LICENSE +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/docs/source/conf.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/setup.cfg +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_identical.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_module.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_utils_optimizer.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/tests/test_vars.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/module.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/transform.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/clipping/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/clipping/clipping.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/clipping/ema_clipping.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/clipping/growth_clipping.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/reduce_outward_lr.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/tropical_newton.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/functional.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/fdm.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/forward_gradient.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/grad_approximator.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/grad_approximation/rfdm.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/backtracking.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/line_search.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/scipy.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/strong_wolfe.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/line_search/trust_region.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/lr/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/lr/lr.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/lr/step_size.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/averaging.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/cautious.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/ema.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/experimental.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/matrix_momentum.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/momentum/momentum.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/accumulate.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/binary.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/debug.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/misc.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/multi.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/reduce.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/split.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/switch.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/unary.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/ops/utility.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/adagrad.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/adam.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/lion.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/muon.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/orthograd.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/rmsprop.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/rprop.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/shampoo.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/sophia_h.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/dct.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/fft.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/galore.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/structural.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/quasi_newton/olbfgs.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/smoothing/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/smoothing/gaussian.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/smoothing/laplacian.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/weight_decay/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/weight_decay/weight_decay.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/wrappers/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/wrappers/optim_wrapper.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/utility/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/utility/split.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/wrappers/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/wrappers/nevergrad.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/wrappers/nlopt.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/optim/wrappers/scipy.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/compile.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/derivatives.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/__init__.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/matrix_funcs.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/orthogonalize.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/qr.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/linalg/svd.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/numberlist.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/ops.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/optimizer.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/optuna_tools.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/params.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/python_tools.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/tensorlist.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero/utils/torch_tools.py +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/dependency_links.txt +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/requires.txt +0 -0
- {torchzero-0.3.6 → torchzero-0.3.9}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.3.6 → torchzero-0.3.9}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.6
+Version: 0.3.9
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 License: MIT License
@@ -156,7 +156,7 @@ for epoch in range(100):
 * `Newton`: Classic Newton's method.
 * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
 * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-* `NystromPCG`: NewtonCG with Nyström preconditioning (
+* `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
 * `LBFGS`: Limited-memory BFGS.
{torchzero-0.3.6 → torchzero-0.3.9}/README.md
@@ -117,7 +117,7 @@ for epoch in range(100):
 * `Newton`: Classic Newton's method.
 * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
 * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
-* `NystromPCG`: NewtonCG with Nyström preconditioning (
+* `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
 
 * **Quasi-Newton**: Approximate second-order optimization methods.
 * `LBFGS`: Limited-memory BFGS.
{torchzero-0.3.6 → torchzero-0.3.9}/pyproject.toml
@@ -2,7 +2,7 @@
 # STEP 1 - COMMIT NEW CHANGES BUT DON'T PUSH THEM YET
 # STEP 2 - BUMP VERSION AND COMMIT IT (DONT PUSH!!!!)
 # STEP 3 - CREATE TAG WITH THAT VERSION
-# STEP 4 - PUSH CHANGES
+# STEP 4 - PUSH (SYNC) CHANGES
 # STEP 5 - PUSH TAG
 
 [build-system]
@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
 name = "torchzero"
 description = "Modular optimization library for PyTorch."
 
-version = "0.3.6"
+version = "0.3.9"
 dependencies = [
     "torch",
     "numpy",
{torchzero-0.3.6 → torchzero-0.3.9}/tests/test_opts.py
@@ -745,7 +745,7 @@ SSVM = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
     sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
     needs_closure=True,
-    func='rosen', steps=50, loss=1e-
+    func='rosen', steps=50, loss=1e-10, merge_invariant=True,
     sphere_steps=10, sphere_loss=0,
 )
 
{torchzero-0.3.6 → torchzero-0.3.9}/tests/test_tensorlist.py
@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho
     expected = vec_equiv_func()
 
     if isinstance(result, bool): assert result == expected
-    else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
+    else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"
 
 
 def test_global_vector_norm(simple_tl: TensorList):
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/core/preconditioner.py
@@ -38,17 +38,18 @@ class Preconditioner(Transform):
 
 
     def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('
+        step = self.global_state.get('__step', 0)
         states = [self.state[p] for p in params]
         settings = [self.settings[p] for p in params]
         global_settings = settings[0]
         update_freq = global_settings['__update_freq']
 
         scale_first = global_settings['__scale_first']
-        scale_factor =
+        scale_factor = 1
         if scale_first and step == 0:
             # initial step size guess from pytorch LBFGS
-            scale_factor = TensorList(tensors).abs().
+            scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
+            scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
 
         # update preconditioner
         if step % update_freq == 0:
@@ -63,13 +64,13 @@ class Preconditioner(Transform):
 
         # scale initial step, when preconditioner might not have been applied
         if scale_first and step == 0:
-            torch.
+            torch._foreach_mul_(tensors, scale_factor)
 
-        self.global_state['
+        self.global_state['__step'] = step + 1
         return tensors
 
     def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-        step = self.global_state.get('
+        step = self.global_state.get('__step', 0)
         tensors_vec = torch.cat([t.ravel() for t in tensors])
         params_vec = torch.cat([p.ravel() for p in params])
         grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
@@ -80,10 +81,11 @@ class Preconditioner(Transform):
         update_freq = global_settings['__update_freq']
 
         scale_first = global_settings['__scale_first']
-        scale_factor =
+        scale_factor = 1
        if scale_first and step == 0:
             # initial step size guess from pytorch LBFGS
-            scale_factor = tensors_vec.abs().sum()
+            scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
+            scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)
 
         # update preconditioner
         if step % update_freq == 0:
@@ -99,11 +101,10 @@ class Preconditioner(Transform):
 
         # scale initial step, when preconditioner might not have been applied
         if scale_first and step == 0:
-
-            tensors_vec /= scale_factor
+            tensors_vec *= scale_factor
 
         tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
-        self.global_state['
+        self.global_state['__step'] = step + 1
         return tensors
 
     @torch.no_grad
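The hunks above replace the old "divide by sum of absolute values" first-step heuristic with its reciprocal applied as a multiplication, clipped to at least machine epsilon. As a quick illustration of the heuristic itself (the PyTorch-LBFGS-style initial step size guess; the helper below is a standalone sketch, not a torchzero API):

import torch

def first_step_scale(update: list[torch.Tensor]) -> torch.Tensor:
    # 1 / max(sum(|u|), 1), floored at machine eps, matching the added lines above
    total = sum(t.abs().sum() for t in update)
    scale = 1.0 / total.clamp(min=1.0)
    return scale.clamp(min=torch.finfo(update[0].dtype).eps)

grads = [torch.randn(10), torch.randn(3, 3)]
scaled = [g * first_step_scale(grads) for g in grads]

Multiplying by this factor bounds the very first update to roughly unit l1 norm before the preconditioner has seen any data, which is what the scale_first option is for.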
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/__init__.py
@@ -3,7 +3,7 @@ from .adadam import Adadam
 from .adamY import AdamY
 from .adasoap import AdaSOAP
 from .curveball import CurveBall
-from .
+from .soapy import SOAPY
 from .gradmin import GradMin
 from .reduce_outward_lr import ReduceOutwardLR
 from .spectral import SpectralPreconditioner
@@ -11,4 +11,5 @@ from .subspace_preconditioners import (
     HistorySubspacePreconditioning,
     RandomSubspacePreconditioning,
 )
-from .tropical_newton import TropicalNewton
+from .tropical_newton import TropicalNewton
+from .newton_solver import NewtonSolver
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/absoap.py
@@ -140,11 +140,17 @@ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
 class ABSOAP(Transform):
     """SOAP but with two extra letters included in its name in order to improve converence
 
+    so what you can do is choose what goes into what ,and that is supposed to be good.
+
     new args
 
     scale by s whether to scale gradient differences by parameter differences
 
     y_to_ema2 whether to use gradient differences for exponential moving average too
+
+    okay I changed these args into another ones
+
+    BASICALLY THIS IS FOR MY EXPERIMENTS
     """
     def __init__(
         self,
@@ -213,7 +219,7 @@ class ABSOAP(Transform):
             if 'g_prev' not in state:
                 state['p_prev'] = p.clone()
                 state['g_prev'] = t.clone()
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue
 
             p_prev = state['p_prev']
@@ -285,7 +291,7 @@ class ABSOAP(Transform):
             state['Q'] = get_orthogonal_matrix(state['GG'])
 
             state['step'] = 0
-            updates.append(tensors[i].
+            updates.append(tensors[i].clip(-0.1,0.1))
             continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
             # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
 
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/adasoap.py
@@ -218,9 +218,9 @@ class AdaSOAP(Transform):
             state['Q'] = get_orthogonal_matrix(GG_precond)
 
             state['step'] = 0
-            updates.append(tensors[i].
+            updates.append(tensors[i].clip(-0.1,0.1))
             continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-            #
+            # that can mess with other modules scaling
 
         # Projecting gradients to the eigenbases of Shampoo's preconditioner
         # i.e. projecting to the eigenbases of matrices in state['GG']
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/algebraic_newton.py
@@ -71,7 +71,7 @@ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemir
 
 
 class AlgebraicNewton(Module):
-    """newton in other algebras, not
+    """newton in other algebras, not that it works."""
     def __init__(
         self,
         reg: float | None = None,
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/curveball.py
@@ -13,7 +13,7 @@ def curveball(
     momentum: float | NumberList,
     precond_lr: float | NumberList,
 ):
-    """returns z_, clone it!!!"""
+    """returns z_, clone it!!! (no just negate it)"""
     delta = Hz + tensors
     z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
     return z_
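For context, the helper above is the CurveBall interpolation step z ← ρz − β(Hz + g). A self-contained sketch of that recursion on a small quadratic (plain torch only; torchzero's own sign convention for updates may differ, hence the docstring note about negation):

import torch

# f(w) = 0.5 wᵀAw − bᵀw, so grad = Aw − b and Hessian = A
A = torch.tensor([[3.0, 0.5], [0.5, 1.0]])
b = torch.tensor([1.0, -2.0])

w = torch.zeros(2)
z = torch.zeros(2)                  # CurveBall direction state
rho, beta = 0.9, 0.1

for _ in range(200):
    g = A @ w - b
    Hz = A @ z
    z = rho * z - beta * (Hz + g)   # z ← ρz − βΔ with Δ = Hz + g
    w = w + z                       # apply the accumulated direction

print(w, torch.linalg.solve(A, b))  # w approaches A⁻¹ b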
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/gradmin.py
@@ -14,7 +14,7 @@ from ..smoothing.gaussian import Reformulation
 
 
 class GradMin(Reformulation):
-    """Reformulates the objective to minimize sum of gradient magnitudes via autograd.
+    """Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.
 
     Args:
         loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
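The reformulation described in that docstring needs a double backward pass: the gradient itself becomes part of the objective. A rough standalone sketch of the idea with plain torch (the function and hyperparameters are made up for illustration):

import torch

def f(w):
    return (w ** 2).sum() + torch.sin(w).sum()

w = torch.randn(5, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

for _ in range(100):
    opt.zero_grad()
    loss = f(w)
    # gradient with create_graph=True so it can itself be differentiated
    (g,) = torch.autograd.grad(loss, w, create_graph=True)
    objective = g.abs().sum() + 1.0 * loss   # loss_term = 1, as in the default above
    objective.backward()
    opt.step()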
torchzero-0.3.9/torchzero/modules/experimental/newton_solver.py
@@ -0,0 +1,88 @@
+from collections.abc import Callable, Iterable
+from typing import Any, Literal, overload
+
+import torch
+
+from ...core import Chainable, Module, apply, Modular
+from ...utils import TensorList, as_tensorlist
+from ...utils.derivatives import hvp
+from ..quasi_newton import LBFGS
+
+class NewtonSolver(Module):
+    """Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
+    def __init__(
+        self,
+        solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
+        maxiter=None,
+        tol=1e-3,
+        reg: float = 0,
+        warm_start=True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        solver_cls = settings['solver']
+        maxiter = settings['maxiter']
+        tol = settings['tol']
+        reg = settings['reg']
+        warm_start = settings['warm_start']
+
+        # ---------------------- Hessian vector product function --------------------- #
+        grad = vars.get_grad(create_graph=True)
+
+        def H_mm(x):
+            with torch.enable_grad():
+                Hvp = TensorList(hvp(params, grad, x, create_graph=True))
+                if reg != 0: Hvp = Hvp + (x*reg)
+                return Hvp
+
+        # -------------------------------- inner step -------------------------------- #
+        b = as_tensorlist(grad)
+        if 'inner' in self.children:
+            b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+
+        # ---------------------------------- run cg ---------------------------------- #
+        x0 = None
+        if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+        if x0 is None: x = b.zeros_like().requires_grad_(True)
+        else: x = x0.clone().requires_grad_(True)
+
+        solver = solver_cls(x)
+        def lstsq_closure(backward=True):
+            Hx = H_mm(x)
+            loss = (Hx-b).pow(2).global_mean()
+            if backward:
+                solver.zero_grad()
+                loss.backward(inputs=x)
+            return loss
+
+        if maxiter is None: maxiter = b.global_numel()
+        loss = None
+        initial_loss = lstsq_closure(False)
+        if initial_loss > tol:
+            for i in range(maxiter):
+                loss = solver.step(lstsq_closure)
+                assert loss is not None
+                if min(loss, loss/initial_loss) < tol: break
+
+            print(f'{loss = }')
+
+        if warm_start:
+            assert x0 is not None
+            x0.copy_(x)
+
+        vars.update = x.detach()
+        return vars
+
+
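The core trick in NewtonSolver is to avoid forming the Hessian: the Newton system Hx = b is recast as the least-squares problem min_x ‖Hx − b‖², which any first-order solver can attack using only Hessian-vector products. A standalone sketch of that idea with a dense H and torch.optim.LBFGS standing in for the pluggable solver (illustrative only, not the torchzero API):

import torch

H = torch.tensor([[4.0, 1.0], [1.0, 3.0]])   # stands in for Hvp access to the true Hessian
g = torch.tensor([1.0, 2.0])

x = torch.zeros(2, requires_grad=True)
opt = torch.optim.LBFGS([x], max_iter=50)

def closure():
    opt.zero_grad()
    loss = (H @ x - g).pow(2).mean()   # ‖Hx − g‖² up to a constant factor
    loss.backward()
    return loss

opt.step(closure)
print(x.detach(), torch.linalg.solve(H, g))  # x ≈ H⁻¹g, the Newton direction

Warm starting from the previous solution, a residual tolerance, and Tikhonov regularization (the reg term added inside H_mm) map directly onto the module's warm_start, tol, and reg options.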
torchzero-0.3.6/torchzero/modules/experimental/dsoap.py → torchzero-0.3.9/torchzero/modules/experimental/soapy.py
@@ -3,7 +3,7 @@ from operator import itemgetter
 import torch
 
 from ...core import Chainable, Transform, apply
-from 
+from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
 
 @torch.no_grad
 def update_soap_covariances_(
@@ -135,7 +135,7 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
 
     return final, exp_avg_sq
 
-class 
+class SOAPY(Transform):
     """SOAP but uses scaled gradient differences
 
     new args
@@ -195,7 +195,7 @@ class DSOAP(Transform):
             if 'g_prev' not in state:
                 state['p_prev'] = p.clone()
                 state['g_prev'] = t.clone()
-                updates.append(tensors[i].
+                updates.append(tensors[i].clip(-0.1,0.1))
                 continue
 
             p_prev = state['p_prev']
@@ -228,7 +228,7 @@ class DSOAP(Transform):
             state['Q'] = get_orthogonal_matrix(state['GG'])
 
             state['step'] = 0
-            updates.append(tensors[i].
+            updates.append(tensors[i].clip(-0.1,0.1))
             continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
             # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
 
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/spectral.py
@@ -194,8 +194,10 @@ class SpectralPreconditioner(TensorwisePreconditioner):
         order (int, optional):
             whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
         solver (str, optional): what to use for whitening. Defaults to 'svd'.
-
-
+        A_beta (float | None, optional):
+            beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
+        B_beta (float | None, optional):
+            beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
         interval (int, optional): How often to update history. Defaults to 1 (every step).
         concat_params (bool, optional):
             whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
@@ -275,7 +277,7 @@ class SpectralPreconditioner(TensorwisePreconditioner):
         A = state.get('A', None)
         if A is None:
             # make a conservative step to avoid issues due to different GD scaling
-            return tensor.
+            return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
 
         B = state['B']
         update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
torchzero-0.3.9/torchzero/modules/experimental/structured_newton.py
@@ -0,0 +1,111 @@
+# idea https://arxiv.org/pdf/2212.09841
+import warnings
+from collections.abc import Callable
+from functools import partial
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, apply
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    hessian_mat,
+    hvp,
+    hvp_fd_central,
+    hvp_fd_forward,
+    jacobian_and_hessian_wrt,
+)
+
+
+class StructuredNewton(Module):
+    """TODO
+    Args:
+        structure (str, optional): structure.
+        reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
+        hvp_method (str):
+            how to calculate hvp_method. Defaults to "autograd".
+        inner (Chainable | None, optional): inner modules. Defaults to None.
+
+    """
+    def __init__(
+        self,
+        structure: Literal[
+            "diagonal",
+            "diagonal1",
+            "diagonal_abs",
+            "tridiagonal",
+            "circulant",
+            "toeplitz",
+            "toeplitz_like",
+            "hankel",
+            "rank1",
+            "rank2", # any rank
+        ]
+        | str = "diagonal",
+        reg: float = 1e-6,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h: float = 1e-3,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        reg = settings['reg']
+        hvp_method = settings['hvp_method']
+        structure = settings['structure']
+        h = settings['h']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        if hvp_method == 'autograd':
+            grad = vars.get_grad(create_graph=True)
+            def Hvp_fn1(x):
+                return hvp(params, grad, x, retain_graph=True)
+            Hvp_fn = Hvp_fn1
+
+        elif hvp_method == 'forward':
+            grad = vars.get_grad()
+            def Hvp_fn2(x):
+                return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
+            Hvp_fn = Hvp_fn2
+
+        elif hvp_method == 'central':
+            grad = vars.get_grad()
+            def Hvp_fn3(x):
+                return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
+            Hvp_fn = Hvp_fn3
+
+        else: raise ValueError(hvp_method)
+
+        # -------------------------------- inner step -------------------------------- #
+        update = vars.get_update()
+        if 'inner' in self.children:
+            update = apply(self.children['inner'], update, params=params, grads=grad, vars=vars)
+
+        # hessian
+        if structure.startswith('diagonal'):
+            H = Hvp_fn([torch.ones_like(p) for p in params])
+            if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
+            if structure == 'diagonal_abs': torch._foreach_abs_(H)
+            torch._foreach_add_(H, reg)
+            torch._foreach_div_(update, H)
+            vars.update = update
+            return vars
+
+        # hessian
+        raise NotImplementedError(structure)
+
+
+
+
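StructuredNewton's "diagonal" modes probe the curvature with a single Hessian-vector product against an all-ones vector. For reference, here is what such an Hvp looks like with plain autograd (a minimal sketch; torchzero's hvp helper wraps the same double-backward mechanics):

import torch

def hessian_vector_product(f, w: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # H v = d/dw (∇f(w) · v), computed without materializing H
    w = w.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(f(w), w, create_graph=True)
    (Hv,) = torch.autograd.grad(grad, w, grad_outputs=v)
    return Hv

f = lambda w: (w ** 4).sum() + w[0] * w[1]
w = torch.tensor([1.0, 2.0])
print(hessian_vector_product(f, w, torch.ones_like(w)))
# H = [[12, 1], [1, 48]] here, so H @ 1 = [13, 49]: the Hessian row sums,
# which match the diagonal exactly only when off-diagonal curvature is negligible.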
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/experimental/subspace_preconditioners.py
RENAMED
@@ -38,16 +38,19 @@ def apply_subspace_preconditioner(
     return basis @ update_projected # d
 
 class RandomSubspacePreconditioning(Transform):
-    """full matrix rmsprop in random subspace"""
-    def __init__(self, k: int, beta: float | None = 0.99):
-        defaults = dict(k=k, beta=beta)
+    """full matrix rmsprop in random slowly changing subspace"""
+    def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
+        defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
         super().__init__(defaults, uses_grad=False)
 
+        if inner is not None: self.set_child('inner', inner)
+
     def transform(self, tensors, params, grads, vars):
         settings = self.settings[params[0]]
         g = torch.cat([t.view(-1) for t in tensors])
         k = settings['k']
         beta = settings['beta']
+        basis_beta = settings['basis_beta']
 
         if 'basis' not in self.global_state:
             self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
@@ -56,13 +59,19 @@ class RandomSubspacePreconditioning(Transform):
         basis = self.global_state['basis']
         accumulator = self.global_state['accumulator']
 
+        if basis_beta is not None:
+            basis.lerp_(torch.randn_like(basis), 1-basis_beta)
+
         update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+        if 'inner' in self.children:
+            tensors = apply(self.children['inner'], tensors, params, grads, vars)
+            g = torch.cat([t.view(-1) for t in tensors])
+
         try:
             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
         except torch.linalg.LinAlgError:
-
-            if denom <= 1e-10: denom = torch.ones_like(denom)
-            preconditioned = g / g.abs().sum()
+            preconditioned = g.clip(-0.1, 0.1)
         vec_to_tensors_(preconditioned, tensors)
 
         return tensors
@@ -119,9 +128,7 @@ class HistorySubspacePreconditioning(Transform):
         try:
             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
         except torch.linalg.LinAlgError:
-
-            if denom <= 1e-10: denom = torch.ones_like(denom)
-            preconditioned = g / g.abs().sum()
+            preconditioned = g.clip(-0.1,0.1)
         vec_to_tensors_(preconditioned, tensors)
 
         return tensors
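The change above makes the random basis drift slowly (a lerp towards fresh Gaussian noise controlled by basis_beta) and lets an inner module shape the update before it is projected. As a rough standalone sketch of the underlying "full-matrix RMSprop in a k-dimensional subspace" idea (the exact accumulator and solve used by torchzero are not shown in this diff, so the inverse-square-root form below is an assumption):

import torch

n, k, beta, basis_beta = 1000, 8, 0.99, 0.99
basis = torch.randn(n, k)          # random, slowly drifting subspace
C = torch.zeros(k, k)              # EMA of projected gradient outer products

def precondition(g: torch.Tensor) -> torch.Tensor:
    global C
    basis.lerp_(torch.randn_like(basis), 1 - basis_beta)   # slow basis drift
    p = basis.T @ g                                        # k-dim projection
    C = beta * C + (1 - beta) * torch.outer(p, p)
    evals, evecs = torch.linalg.eigh(C + 1e-8 * torch.eye(k))
    inv_sqrt = evecs @ torch.diag(evals.clamp(min=1e-12).rsqrt()) @ evecs.T
    return basis @ (inv_sqrt @ p)                          # map back to R^n

step = precondition(torch.randn(n))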
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/optimizers/soap.py
@@ -222,8 +222,7 @@ class SOAP(Transform):
             state['Q'] = get_orthogonal_matrix(state['GG'])
 
             state['step'] = 0
-            updates.append(tensors[i].
-            # updates.append(tensors[i] / tensors[i].abs().sum())
+            updates.append(tensors[i].clip(-0.1, 0.1))
             continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
             # I use scaled update instead as to not mess up with next modules.
 
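This first-step fallback (also applied in ABSOAP, AdaSOAP, SOAPY and SpectralPreconditioner above) replaces the old scaled passthrough with a clip to ±0.1. The surrounding comments rely on the fact that Adam's very first update is essentially sign(g), since after one step both moment estimates are built from the same gradient; a two-line check of that claim:

import torch

g = torch.tensor([3.0, -0.2, 0.01])
# with bias correction, m̂ = g and v̂ = g² after one step, so the ratio is ±1 elementwise
print(g / (g.abs() + 1e-8))   # ≈ sign(g)
print(g.clip(-0.1, 0.1))      # the bounded fallback used in the diff instead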
{torchzero-0.3.6 → torchzero-0.3.9}/torchzero/modules/projections/projection.py
@@ -1,4 +1,5 @@
 import math
+from functools import partial
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Any, Literal
@@ -33,6 +34,25 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",
 
     return projected_closure
 
+def _projected_get_grad_override(
+    retain_graph: bool | None = None,
+    create_graph: bool = False,
+    projection: Any = ...,
+    unprojected_vars: Any = ...,
+    self: Any = ...,
+):
+    assert isinstance(projection, Projection)
+    assert isinstance(unprojected_vars, Vars)
+    assert isinstance(self, Vars)
+
+    if self.grad is not None: return self.grad
+    grads = unprojected_vars.get_grad(retain_graph, create_graph)
+    projected_grads = list(projection.project(grads, self, current='grads'))
+    self.grad = projected_grads
+    for p, g in zip(self.params, projected_grads):
+        p.grad = g
+    return self.grad
+
 
 class Projection(Module, ABC):
     """
@@ -137,6 +157,12 @@
 
         # step
         projected_vars.params = self._projected_params
+        projected_vars.get_grad = partial(
+            _projected_get_grad_override,
+            projection=self,
+            unprojected_vars=vars,
+            self=projected_vars,
+        )
         projected_vars = self.children['modules'].step(projected_vars)
 
         # empty fake params storage
@@ -149,7 +175,7 @@
         unprojected_vars = projected_vars.clone(clone_update=False)
         unprojected_vars.closure = vars.closure
         unprojected_vars.params = vars.params
-
+        unprojected_vars.grad = vars.grad
 
         if self._project_update:
             assert projected_vars.update is not None
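The new _projected_get_grad_override is bound onto the projected Vars instance with functools.partial, so that when a child module asks for the gradient it is computed lazily on the real parameters and then projected. A toy illustration of that "swap in a partial as a method" pattern (the classes and projection below are hypothetical stand-ins, not torchzero's Vars or Projection):

import torch
from functools import partial

class VarsSketch:
    def __init__(self, params):
        self.params = params
        self.grad = None
    def get_grad(self):
        # stand-in for evaluating the closure and reading .grad off the params
        self.grad = [torch.ones_like(p) for p in self.params]
        return self.grad

def projected_get_grad(projection=None, unprojected=None, self=None):
    # compute the gradient lazily on the real params, then project and cache it
    if self.grad is not None:
        return self.grad
    self.grad = [projection(g) for g in unprojected.get_grad()]
    return self.grad

outer = VarsSketch([torch.zeros(3)])
projected = VarsSketch([torch.zeros(2)])   # "fake" projected params
projected.get_grad = partial(projected_get_grad,
                             projection=lambda g: 2.0 * g[:2],   # toy projection R³ → R²
                             unprojected=outer, self=projected)
print(projected.get_grad())   # the 3-d gradient is computed on outer, then projected to 2-d

Because the partial is stored on the instance, it shadows the class method only for that particular projected Vars, which is what the diff relies on.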