torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
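The structural theme of the release is visible in the file list: `torchzero/core/var.py` and `tests/test_vars.py` are removed, `torchzero/core/objective.py` and `tests/test_objective.py` are new, and `torchzero/utils/linalg` moves to a top-level `torchzero/linalg` package. In the hunks below, modules correspondingly receive an `objective` and implement `update`/`apply` hooks. A minimal sketch of a custom module against that surface, with the hook names and attributes read off the hunks below rather than from any documentation (the module itself is hypothetical):

```py
import torch
from torchzero.core import Module

class SignUpdate(Module):
    """Hypothetical example module: replaces the current update with its sign."""
    @torch.no_grad
    def apply(self, objective):
        # objective.get_updates() / objective.updates mirror the accessors used in the hunks below
        objective.updates = [u.sign() for u in objective.get_updates()]
        return objective
```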
torchzero/modules/experimental/scipy_newton_cg.py CHANGED

@@ -3,10 +3,9 @@ from typing import Literal, overload
 import torch
 from scipy.sparse.linalg import LinearOperator, gcrotmk

-from ...core import Chainable, Module,
-from ...utils import
-from ...utils.derivatives import
-from ...utils.linalg.solve import cg, minres
+from ...core import Chainable, Module, step
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import hvp_fd_central, hvp_fd_forward


 class ScipyNewtonCG(Module):

@@ -14,7 +13,7 @@ class ScipyNewtonCG(Module):
     def __init__(
         self,
         solver = gcrotmk,
-        hvp_method: Literal["
+        hvp_method: Literal["fd_forward", "fd_central", "autograd"] = "autograd",
         h: float = 1e-3,
         warm_start=False,
         inner: Chainable | None = None,

@@ -33,47 +32,47 @@ class ScipyNewtonCG(Module):
         self._kwargs = kwargs

     @torch.no_grad
-    def
-        params = TensorList(
-        closure =
+    def apply(self, objective):
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

-
-        hvp_method =
-        solver =
-        h =
-        warm_start =
+        fs = self.settings[params[0]]
+        hvp_method = fs['hvp_method']
+        solver = fs['solver']
+        h = fs['h']
+        warm_start = fs['warm_start']

         self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
         device = params[0].device; dtype=params[0].dtype
         if hvp_method == 'autograd':
-            grad =
+            grad = objective.get_grads(create_graph=True)

             def H_mm(x_np):
                 self._num_hvps_last_step += 1
                 x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
                 with torch.enable_grad():
-                    Hvp = TensorList(
+                    Hvp = TensorList(torch.autograd.grad(grad, params, x, retain_graph=True))
                 return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)

         else:

             with torch.enable_grad():
-                grad =
+                grad = objective.get_grads()

             if hvp_method == 'forward':
                 def H_mm(x_np):
                     self._num_hvps_last_step += 1
                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
-                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad
+                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad)[1])
                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)

             elif hvp_method == 'central':
                 def H_mm(x_np):
                     self._num_hvps_last_step += 1
                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
-                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h
+                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h)[1])
                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)

             else:

@@ -83,10 +82,8 @@ class ScipyNewtonCG(Module):
         H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore

         # -------------------------------- inner step -------------------------------- #
-
-
-        b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-        b = as_tensorlist(b)
+        objective = self.inner_step("inner", objective, must_exist=False)
+        b = TensorList(objective.get_updates())

         # ---------------------------------- run cg ---------------------------------- #
         x0 = None

@@ -98,8 +95,8 @@ class ScipyNewtonCG(Module):
         if warm_start:
             self.global_state['x_prev'] = x_np

-
+        objective.updates = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)

         self._num_hvps += self._num_hvps_last_step
-        return
+        return objective

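The `apply` method above funnels every Hessian-vector product through a `scipy.sparse.linalg.LinearOperator` so that a SciPy Krylov solver (here `gcrotmk`) can produce the Newton step without ever materializing the Hessian. A minimal standalone sketch of that pattern on a toy quadratic; this is not torchzero code, and all names are local to the example:

```py
import torch
from scipy.sparse.linalg import LinearOperator, gcrotmk

x = torch.tensor([1.5, -0.7], requires_grad=True)
loss = 0.5 * (x[0]**2 + 10 * x[1]**2)            # toy quadratic, H = diag(1, 10)
(g,) = torch.autograd.grad(loss, x, create_graph=True)

def H_mm(v_np):
    # Hessian-vector product via double backward, returned as a numpy vector
    v = torch.as_tensor(v_np, dtype=x.dtype)
    (Hv,) = torch.autograd.grad(g, x, v, retain_graph=True)
    return Hv.detach().numpy()

H = LinearOperator(shape=(2, 2), matvec=H_mm, rmatvec=H_mm)
step, info = gcrotmk(H, g.detach().numpy())      # solve H @ step = g
print(step)                                       # ~[1.5, -0.7]: the exact Newton step
```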
torchzero/modules/experimental/spsa1.py CHANGED

@@ -38,15 +38,15 @@ class SPSA1(GradApproximator):
         super().__init__(defaults, target=target)


-    def pre_step(self,
+    def pre_step(self, objective):

         if self.defaults['pre_generate']:

-            params = TensorList(
+            params = TensorList(objective.params)
             generator = self.get_generator(params[0].device, self.defaults['seed'])

             n_samples = self.defaults['n_samples']
-            h = self.get_settings(
+            h = self.get_settings(objective.params, 'h')

             perturbations = [params.rademacher_like(generator=generator) for _ in range(n_samples)]
             torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])

torchzero/modules/experimental/structural_projections.py CHANGED

@@ -1,11 +1,8 @@
 import math
-
-import numpy as np
 import torch

 from ...core import Chainable
-from ...utils import vec_to_tensors
-from ..adaptive.shampoo import _merge_small_dims
+from ...utils import vec_to_tensors
 from ..projections import ProjectionBase


torchzero/modules/functional.py CHANGED

@@ -30,7 +30,7 @@ def debiased_step_size(
     pow: float = 2,
     alpha: float | NumberList = 1,
 ):
-    """returns multiplier to step size"""
+    """returns multiplier to step size, step starts from 1"""
     if isinstance(beta1, NumberList): beta1 = beta1.fill_none(0)
     if isinstance(beta2, NumberList): beta2 = beta2.fill_none(0)

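The clarified docstring matters because Adam-style bias correction divides by `1 - beta**step`, which is zero at step 0. A reference sketch of the standard bias-corrected multiplier; the exact formula `debiased_step_size` uses (including `pow` and `alpha`) is not shown in this diff, so treat this as an illustration of the convention only:

```py
def debias_multiplier(step: int, beta1: float = 0.9, beta2: float = 0.999) -> float:
    # standard Adam-style correction; step must start from 1, otherwise 1 - beta1**0 == 0
    return (1 - beta2 ** step) ** 0.5 / (1 - beta1 ** step)

print(debias_multiplier(1))       # ~0.316: strong correction on the very first step
print(debias_multiplier(10_000))  # ~1.0: no correction once the EMAs are warmed up
```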
torchzero/modules/grad_approximation/forward_gradient.py CHANGED

@@ -52,11 +52,11 @@ class ForwardGradient(RandomizedFDM):
         params = TensorList(params)
         loss_approx = None

-
-        n_samples =
-        jvp_method =
-        h =
-        distribution =
+        fs = self.settings[params[0]]
+        n_samples = fs['n_samples']
+        jvp_method = fs['jvp_method']
+        h = fs['h']
+        distribution = fs['distribution']
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
         generator = self.get_generator(params[0].device, self.defaults['seed'])

@@ -74,10 +74,10 @@ class ForwardGradient(RandomizedFDM):
                 loss, d = jvp(partial(closure, False), params=params, tangent=prt)

             elif jvp_method == 'forward':
-                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss,
+                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, h=h)

             elif jvp_method == 'central':
-                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt,
+                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, h=h)

             else: raise ValueError(jvp_method)

torchzero/modules/grad_approximation/grad_approximator.py CHANGED

@@ -5,7 +5,7 @@ from typing import Any, Literal

 import torch

-from ...core import Module,
+from ...core import Module, Objective

 GradTarget = Literal['update', 'grad', 'closure']
 _Scalar = torch.Tensor | float

@@ -62,24 +62,25 @@ class GradApproximator(Module, ABC):
            return spsa_grads, None, loss_plus
        ```
    """
-    def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
+    def __init__(self, defaults: dict[str, Any] | None = None, return_approx_loss:bool=False, target: GradTarget = 'closure'):
         super().__init__(defaults)
         self._target: GradTarget = target
+        self._return_approx_loss = return_approx_loss

     @abstractmethod
     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
         """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

-    def pre_step(self,
+    def pre_step(self, objective: Objective) -> None:
         """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""

     @torch.no_grad
-    def
-        self.pre_step(
+    def update(self, objective):
+        self.pre_step(objective)

-        if
-        params, closure, loss =
+        if objective.closure is None: raise RuntimeError("Gradient approximation requires closure")
+        params, closure, loss = objective.params, objective.closure, objective.loss

         if self._target == 'closure':

@@ -88,20 +89,26 @@ class GradApproximator(Module, ABC):
                # set loss to None because closure might be evaluated at different points
                grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                for p, g in zip(params, grad): p.grad = g
-
+                if l is not None: return l
+                if self._return_approx_loss and l_approx is not None: return l_approx
+                return closure(False)
+
                return closure(False)

-
-            return
+            objective.closure = approx_closure
+            return

         # if var.grad is not None:
         #     warnings.warn('Using grad approximator when `var.grad` is already set.')
-        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
-        if loss_approx is not None:
-        if loss is not None:
-        if self._target == 'grad':
-        elif self._target == 'update':
+        grad, loss, loss_approx = self.approximate(closure=closure, params=params, loss=loss)
+        if loss_approx is not None: objective.loss_approx = loss_approx
+        if loss is not None: objective.loss = objective.loss_approx = loss
+        if self._target == 'grad': objective.grads = list(grad)
+        elif self._target == 'update': objective.updates = list(grad)
         else: raise ValueError(self._target)
-        return
+        return
+
+    def apply(self, objective):
+        return objective

 _FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']

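When `target='closure'` the module computes nothing up front; it swaps the user's closure for one that fills `p.grad` with the approximation whenever a later module (a line search, a trust region) asks for a backward pass. A generic illustration of that wrapping pattern, not the torchzero implementation itself:

```py
import torch

def make_approx_closure(closure, params, approximate):
    # Wrap `closure` so that backward=True fills p.grad from `approximate`
    # (finite differences, SPSA, ...) instead of running autograd.
    def approx_closure(backward=True):
        if backward:
            grads, loss = approximate(closure, params)
            for p, g in zip(params, grads):
                p.grad = g
            return loss if loss is not None else closure(False)
        return closure(False)
    return approx_closure

# tiny demo: "approximate" here just returns the analytic gradient of x^2
x = torch.tensor([3.0])
closure = lambda backward=True: (x ** 2).sum()
approx = lambda closure, params: ([2 * p for p in params], None)
wrapped = make_approx_closure(closure, [x], approx)
print(wrapped(True), x.grad)   # tensor(9.) tensor([6.])
```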
torchzero/modules/grad_approximation/rfdm.py CHANGED

@@ -176,7 +176,7 @@ class RandomizedFDM(GradApproximator):
    ```py
    spsa = tz.Modular(
        model.parameters(),
-        tz.m.RandomizedFDM(formula="
+        tz.m.RandomizedFDM(formula="fd_central", distribution="rademacher"),
        tz.m.LR(1e-2)
    )
    ```

@@ -187,7 +187,7 @@
    ```
    rdsa = tz.Modular(
        model.parameters(),
-        tz.m.RandomizedFDM(formula="
+        tz.m.RandomizedFDM(formula="fd_central", distribution="gaussian"),
        tz.m.LR(1e-2)
    )
    ```

@@ -223,23 +223,24 @@
         n_samples: int = 1,
         formula: _FD_Formula = "central",
         distribution: Distributions = "rademacher",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         seed: int | None | torch.Generator = None,
         target: GradTarget = "closure",
     ):
         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)


-    def pre_step(self,
-        h = self.get_settings(
+    def pre_step(self, objective):
+        h = self.get_settings(objective.params, 'h')
         pre_generate = self.defaults['pre_generate']

         if pre_generate:
             n_samples = self.defaults['n_samples']
             distribution = self.defaults['distribution']

-            params = TensorList(
+            params = TensorList(objective.params)
             generator = self.get_generator(params[0].device, self.defaults['seed'])
             perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

@@ -346,11 +347,12 @@ class RDSA(RandomizedFDM):
         n_samples: int = 1,
         formula: _FD_Formula = "central2",
         distribution: Distributions = "gaussian",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)

 class GaussianSmoothing(RandomizedFDM):
     """

@@ -380,11 +382,12 @@ class GaussianSmoothing(RandomizedFDM):
         n_samples: int = 100,
         formula: _FD_Formula = "forward2",
         distribution: Distributions = "gaussian",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)

 class MeZO(GradApproximator):
     """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

@@ -406,10 +409,10 @@ class MeZO(GradApproximator):
     """

     def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
-                 distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+                 distribution: Distributions = 'rademacher', return_approx_loss: bool = False, target: GradTarget = 'closure'):

         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)

     def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
         prt = TensorList(params).sample_like(

@@ -419,19 +422,19 @@
         )
         return prt

-    def pre_step(self,
-        h = NumberList(self.settings[p]['h'] for p in
+    def pre_step(self, objective):
+        h = NumberList(self.settings[p]['h'] for p in objective.params)

         n_samples = self.defaults['n_samples']
         distribution = self.defaults['distribution']

-        step =
+        step = objective.current_step

         # create functions that generate a deterministic perturbation from seed based on current step
         prt_fns = []
         for i in range(n_samples):

-            prt_fn = partial(self._seeded_perturbation, params=
+            prt_fn = partial(self._seeded_perturbation, params=objective.params, distribution=distribution, seed=1_000_000*step + i, h=h)
             prt_fns.append(prt_fn)

         self.global_state['prt_fns'] = prt_fns

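MeZO's memory saving comes from never storing the perturbations: `pre_step` above only records functions keyed by `seed=1_000_000*step + i`, and the same noise is re-generated on demand. A standalone sketch of that idea (not torchzero code):

```py
import torch

def seeded_perturbation(shape, seed):
    # rebuild the exact same "random" tensor from a deterministic seed
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(shape, generator=gen)

step, i = 7, 0
seed = 1_000_000 * step + i
a = seeded_perturbation((3,), seed)
b = seeded_perturbation((3,), seed)   # regenerated later, nothing was stored
print(torch.equal(a, b))              # True
```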
torchzero/modules/least_squares/gn.py CHANGED

@@ -1,28 +1,31 @@
 import torch
-from ...core import Module

-from ...
+from ...core import Chainable, Module, step
+from ...linalg import linear_operator
 from ...utils import vec_to_tensors
-from ...utils.
+from ...utils.derivatives import flatten_jacobian, jacobian_wrt
+
+
 class SumOfSquares(Module):
     """Sets loss to be the sum of squares of values returned by the closure.

     This is meant to be used to test least squares methods against ordinary minimization methods.

     To use this, the closure should return a vector of values to minimize sum of squares of.
-    Please add the
+    Please add the ``backward`` argument, it will always be False but it is required.
     """
     def __init__(self):
         super().__init__()

     @torch.no_grad
-    def
-        closure =
+    def update(self, objective):
+        closure = objective.closure

         if closure is not None:
+
             def sos_closure(backward=True):
                 if backward:
-
+                    objective.zero_grad()
                     with torch.enable_grad():
                         loss = closure(False)
                         loss = loss.pow(2).sum()

@@ -32,16 +35,13 @@ class SumOfSquares(Module):
                loss = closure(False)
                return loss.pow(2).sum()

-
-
-        if var.loss is not None:
-            var.loss = var.loss.pow(2).sum()
+            objective.closure = sos_closure

-        if
-
-
-        return var
+        if objective.loss is not None:
+            objective.loss = objective.loss.pow(2).sum()

+        if objective.loss_approx is not None:
+            objective.loss_approx = objective.loss_approx.pow(2).sum()

 class GaussNewton(Module):
     """Gauss-newton method.

@@ -101,35 +101,45 @@ class GaussNewton(Module):
        print(f'{losses.mean() = }')
    ```
    """
-    def __init__(self, reg:float = 1e-8, batched:bool=True, ):
+    def __init__(self, reg:float = 1e-8, batched:bool=True, inner: Chainable | None = None):
         super().__init__(defaults=dict(batched=batched, reg=reg))
+        if inner is not None: self.set_child('inner', inner)

     @torch.no_grad
-    def update(self,
-        params =
+    def update(self, objective):
+        params = objective.params
         batched = self.defaults['batched']

-        closure =
+        closure = objective.closure
         assert closure is not None

         # gauss newton direction
         with torch.enable_grad():
-
-            assert isinstance(
-
+            r = objective.get_loss(backward=False) # nresiduals
+            assert isinstance(r, torch.Tensor)
+            J_list = jacobian_wrt([r.ravel()], params, batched=batched)
+
+        objective.loss = r.pow(2).sum()
+
+        J = self.global_state["J"] = flatten_jacobian(J_list) # (nresiduals, ndim)
+        Jr = J.T @ r.detach() # (ndim)
+
+        # if there are more residuals, solve (J^T J)x = J^T r, so we need Jr
+        # otherwise solve (J J^T)z = r and set x = J^T z, so we need r
+        nresiduals, ndim = J.shape
+        if nresiduals >= ndim or "inner" in self.children:
+            self.global_state["Jr"] = Jr

-
+        else:
+            self.global_state["r"] = r

-
-        Gtf = G.T @ f.detach() # (ndim)
-        self.global_state["Gtf"] = Gtf
-        var.grad = vec_to_tensors(Gtf, var.params)
+        objective.grads = vec_to_tensors(Jr, objective.params)

         # set closure to calculate sum of squares for line searches etc
-        if
+        if objective.closure is not None:
             def sos_closure(backward=True):
                 if backward:
-
+                    objective.zero_grad()
                     with torch.enable_grad():
                         loss = closure(False).pow(2).sum()
                         loss.backward()

@@ -138,24 +148,62 @@ class GaussNewton(Module):
                 loss = closure(False).pow(2).sum()
                 return loss

-
+            objective.closure = sos_closure

     @torch.no_grad
-    def apply(self,
+    def apply(self, objective):
         reg = self.defaults['reg']

-
-
+        J: torch.Tensor = self.global_state['J']
+        nresiduals, ndim = J.shape
+        if nresiduals >= ndim or "inner" in self.children:
+
+            # (J^T J)v = J^T r
+            Jr: torch.Tensor = self.global_state['Jr']
+
+            # inner step
+            if "inner" in self.children:
+
+                # var.grad is set to unflattened Jr
+                assert objective.grads is not None
+                objective = self.inner_step("inner", objective, must_exist=True)
+                Jr_list = objective.get_updates()
+                Jr = torch.cat([t.ravel() for t in Jr_list])
+
+            JJ = J.T @ J # (ndim, ndim)
+            if reg != 0:
+                JJ.add_(torch.eye(JJ.size(0), device=JJ.device, dtype=JJ.dtype).mul_(reg))
+
+            if nresiduals >= ndim:
+                v, info = torch.linalg.solve_ex(JJ, Jr) # pylint:disable=not-callable
+            else:
+                v = torch.linalg.lstsq(JJ, Jr).solution # pylint:disable=not-callable
+
+            objective.updates = vec_to_tensors(v, objective.params)
+            return objective
+
+        else:
+            # solve (J J^T)z = r and set v = J^T z
+            # derivation
+            # we need (J^T J)v = J^T r
+            # suppose z is solution to (G G^T)z = r, and v = J^T z
+            # if v = J^T z, then (J^T J)v = (J^T J) (J^T z) = J^T (J J^T) z = J^T r
+            # therefore with our presuppositions (J^T J)v = J^T r
+
+            # also this gives a minimum norm solution
+
+            r = self.global_state['r']

-
-
-
+            JJT = J @ J.T # (nresiduals, nresiduals)
+            if reg != 0:
+                JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))

-
+            z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
+            v = J.T @ z

-
-
+            objective.updates = vec_to_tensors(v, objective.params)
+            return objective

-    def get_H(self,
-
-        return linear_operator.AtA(
+    def get_H(self, objective=...):
+        J = self.global_state['J']
+        return linear_operator.AtA(J)

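The derivation in the comments above (solve `(J J^T) z = r`, then take `v = J^T z` when there are fewer residuals than parameters) can be checked numerically. A standalone sketch, not torchzero code, comparing it against the minimum-norm pseudoinverse solution:

```py
import torch

torch.manual_seed(0)
nresiduals, ndim = 3, 10                      # fewer residuals than parameters
J = torch.randn(nresiduals, ndim, dtype=torch.float64)
r = torch.randn(nresiduals, dtype=torch.float64)

# dual route from the diff: (J J^T) z = r, v = J^T z
z = torch.linalg.solve(J @ J.T, r)
v_dual = J.T @ z

# reference: minimum-norm solution of J v = r (equivalently of (J^T J) v = J^T r)
v_ref = torch.linalg.pinv(J) @ r

print(torch.allclose(v_dual, v_ref))          # True
print(J @ v_dual - r)                         # ~0: the residual equations hold exactly
```

This is why the underdetermined branch only ever factors the small `nresiduals x nresiduals` matrix `J @ J.T` instead of the full `ndim x ndim` matrix `J.T @ J`.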
torchzero/modules/line_search/backtracking.py CHANGED

@@ -117,7 +117,7 @@ class Backtracking(LineSearchBase):

         # # directional derivative
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), var.get_updates()))

         # scale init
         init_scale = self.global_state.get('init_scale', 1)

@@ -199,7 +199,7 @@ class AdaptiveBacktracking(LineSearchBase):

         # directional derivative (0 if c = 0 because it is not needed)
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), update))

         # scale beta
         beta = beta * self.global_state['beta_scale']