torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0

torchzero/modules/experimental/spsa1.py
@@ -38,15 +38,15 @@ class SPSA1(GradApproximator):
         super().__init__(defaults, target=target)


-    def pre_step(self,
+    def pre_step(self, objective):

         if self.defaults['pre_generate']:

-            params = TensorList(
+            params = TensorList(objective.params)
             generator = self.get_generator(params[0].device, self.defaults['seed'])

             n_samples = self.defaults['n_samples']
-            h = self.get_settings(
+            h = self.get_settings(objective.params, 'h')

             perturbations = [params.rademacher_like(generator=generator) for _ in range(n_samples)]
             torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])

torchzero/modules/experimental/structural_projections.py
@@ -1,11 +1,8 @@
 import math
-
-import numpy as np
 import torch

 from ...core import Chainable
-from ...utils import vec_to_tensors
-from ..adaptive.shampoo import _merge_small_dims
+from ...utils import vec_to_tensors
 from ..projections import ProjectionBase



torchzero/modules/grad_approximation/fdm.py
@@ -106,12 +106,12 @@ class FDM(GradApproximator):
     plain FDM:

     ```python
-    fdm = tz.
+    fdm = tz.Optimizer(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
     ```

     Any gradient-based method can use FDM-estimated gradients.
     ```python
-    fdm_ncg = tz.
+    fdm_ncg = tz.Optimizer(
         model.parameters(),
         tz.m.FDM(),
         # set hvp_method to "forward" so that it

torchzero/modules/grad_approximation/forward_gradient.py
@@ -52,11 +52,11 @@ class ForwardGradient(RandomizedFDM):
         params = TensorList(params)
         loss_approx = None

-
-        n_samples =
-        jvp_method =
-        h =
-        distribution =
+        fs = self.settings[params[0]]
+        n_samples = fs['n_samples']
+        jvp_method = fs['jvp_method']
+        h = fs['h']
+        distribution = fs['distribution']
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
         generator = self.get_generator(params[0].device, self.defaults['seed'])
@@ -74,10 +74,10 @@ class ForwardGradient(RandomizedFDM):
                 loss, d = jvp(partial(closure, False), params=params, tangent=prt)

             elif jvp_method == 'forward':
-                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss,
+                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, h=h)

             elif jvp_method == 'central':
-                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt,
+                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, h=h)

             else: raise ValueError(jvp_method)


torchzero/modules/grad_approximation/grad_approximator.py
@@ -5,7 +5,7 @@ from typing import Any, Literal

 import torch

-from ...core import Module,
+from ...core import Module, Objective

 GradTarget = Literal['update', 'grad', 'closure']
 _Scalar = torch.Tensor | float
@@ -62,24 +62,25 @@ class GradApproximator(Module, ABC):
             return spsa_grads, None, loss_plus
     ```
     """
-    def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
+    def __init__(self, defaults: dict[str, Any] | None = None, return_approx_loss:bool=False, target: GradTarget = 'closure'):
         super().__init__(defaults)
         self._target: GradTarget = target
+        self._return_approx_loss = return_approx_loss

     @abstractmethod
     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
         """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

-    def pre_step(self,
+    def pre_step(self, objective: Objective) -> None:
         """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""

     @torch.no_grad
-    def
-        self.pre_step(
+    def update(self, objective):
+        self.pre_step(objective)

-        if
-        params, closure, loss =
+        if objective.closure is None: raise RuntimeError("Gradient approximation requires closure")
+        params, closure, loss = objective.params, objective.closure, objective.loss

         if self._target == 'closure':

@@ -88,20 +89,26 @@ class GradApproximator(Module, ABC):
                     # set loss to None because closure might be evaluated at different points
                     grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                     for p, g in zip(params, grad): p.grad = g
-
+                    if l is not None: return l
+                    if self._return_approx_loss and l_approx is not None: return l_approx
+                    return closure(False)
+
                 return closure(False)

-
-            return
+            objective.closure = approx_closure
+            return

         # if var.grad is not None:
         #     warnings.warn('Using grad approximator when `var.grad` is already set.')
-        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
-        if loss_approx is not None:
-        if loss is not None:
-        if self._target == 'grad':
-        elif self._target == 'update':
+        grad, loss, loss_approx = self.approximate(closure=closure, params=params, loss=loss)
+        if loss_approx is not None: objective.loss_approx = loss_approx
+        if loss is not None: objective.loss = objective.loss_approx = loss
+        if self._target == 'grad': objective.grads = list(grad)
+        elif self._target == 'update': objective.updates = list(grad)
         else: raise ValueError(self._target)
-        return
+        return
+
+    def apply(self, objective):
+        return objective

 _FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']
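
The `GradApproximator` base class keeps the `approximate(closure, params, loss) -> (grad, loss, loss_approx)` contract shown above while the step handling moves from `var` to `objective` and gains `return_approx_loss`. Below is a minimal sketch of a custom approximator written against that contract; the import path follows the file listing, but the class itself (a single-direction central difference) is illustrative and not part of the package:

```python
import torch
from torchzero.modules.grad_approximation.grad_approximator import GradApproximator

class TwoPointSPSA(GradApproximator):
    """Central-difference gradient estimate along one random rademacher direction."""
    def __init__(self, h: float = 1e-3):
        super().__init__(dict(h=h))

    @torch.no_grad
    def approximate(self, closure, params, loss):
        h = self.defaults['h']
        d = [torch.randint_like(p, 0, 2).mul_(2).sub_(1) for p in params]  # entries in {-1, +1}

        for p, v in zip(params, d): p.add_(v, alpha=h)      # evaluate at x + h*d
        f_plus = closure(False)
        for p, v in zip(params, d): p.sub_(v, alpha=2 * h)  # evaluate at x - h*d
        f_minus = closure(False)
        for p, v in zip(params, d): p.add_(v, alpha=h)      # restore original parameters

        grads = [v * (f_plus - f_minus) / (2 * h) for v in d]
        # (grad, loss, loss_approx): the exact loss was never computed, return the midpoint estimate
        return grads, None, (f_plus + f_minus) / 2
```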

torchzero/modules/grad_approximation/rfdm.py
@@ -174,9 +174,9 @@ class RandomizedFDM(GradApproximator):

     SPSA is randomized FDM with rademacher distribution and central formula.
     ```py
-    spsa = tz.
+    spsa = tz.Optimizer(
         model.parameters(),
-        tz.m.RandomizedFDM(formula="
+        tz.m.RandomizedFDM(formula="fd_central", distribution="rademacher"),
         tz.m.LR(1e-2)
     )
     ```
@@ -185,9 +185,9 @@ class RandomizedFDM(GradApproximator):

     RDSA is randomized FDM with usually gaussian distribution and central formula.
     ```
-    rdsa = tz.
+    rdsa = tz.Optimizer(
         model.parameters(),
-        tz.m.RandomizedFDM(formula="
+        tz.m.RandomizedFDM(formula="fd_central", distribution="gaussian"),
         tz.m.LR(1e-2)
     )
     ```
@@ -196,7 +196,7 @@ class RandomizedFDM(GradApproximator):

     GS uses many gaussian samples with possibly a larger finite difference step size.
     ```
-    gs = tz.
+    gs = tz.Optimizer(
         model.parameters(),
         tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
         tz.m.NewtonCG(hvp_method="forward"),
@@ -208,7 +208,7 @@ class RandomizedFDM(GradApproximator):

     Momentum might help by reducing the variance of the estimated gradients.
     ```
-    momentum_spsa = tz.
+    momentum_spsa = tz.Optimizer(
         model.parameters(),
         tz.m.RandomizedFDM(),
         tz.m.HeavyBall(0.9),
@@ -223,23 +223,24 @@ class RandomizedFDM(GradApproximator):
         n_samples: int = 1,
         formula: _FD_Formula = "central",
         distribution: Distributions = "rademacher",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         seed: int | None | torch.Generator = None,
         target: GradTarget = "closure",
     ):
         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)


-    def pre_step(self,
-        h = self.get_settings(
+    def pre_step(self, objective):
+        h = self.get_settings(objective.params, 'h')
         pre_generate = self.defaults['pre_generate']

         if pre_generate:
             n_samples = self.defaults['n_samples']
             distribution = self.defaults['distribution']

-            params = TensorList(
+            params = TensorList(objective.params)
             generator = self.get_generator(params[0].device, self.defaults['seed'])
             perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

@@ -346,11 +347,12 @@ class RDSA(RandomizedFDM):
         n_samples: int = 1,
         formula: _FD_Formula = "central2",
         distribution: Distributions = "gaussian",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)

 class GaussianSmoothing(RandomizedFDM):
     """
@@ -380,11 +382,12 @@ class GaussianSmoothing(RandomizedFDM):
         n_samples: int = 100,
         formula: _FD_Formula = "forward2",
         distribution: Distributions = "gaussian",
-        pre_generate = True,
+        pre_generate: bool = True,
+        return_approx_loss: bool = False,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)

 class MeZO(GradApproximator):
     """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.
@@ -406,10 +409,10 @@ class MeZO(GradApproximator):
     """

     def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
-                 distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+                 distribution: Distributions = 'rademacher', return_approx_loss: bool = False, target: GradTarget = 'closure'):

         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
-        super().__init__(defaults, target=target)
+        super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)

     def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
         prt = TensorList(params).sample_like(
@@ -419,19 +422,19 @@ class MeZO(GradApproximator):
         )
         return prt

-    def pre_step(self,
-        h = NumberList(self.settings[p]['h'] for p in
+    def pre_step(self, objective):
+        h = NumberList(self.settings[p]['h'] for p in objective.params)

         n_samples = self.defaults['n_samples']
         distribution = self.defaults['distribution']

-        step =
+        step = objective.current_step

         # create functions that generate a deterministic perturbation from seed based on current step
         prt_fns = []
         for i in range(n_samples):

-            prt_fn = partial(self._seeded_perturbation, params=
+            prt_fn = partial(self._seeded_perturbation, params=objective.params, distribution=distribution, seed=1_000_000*step + i, h=h)
             prt_fns.append(prt_fn)

         self.global_state['prt_fns'] = prt_fns
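
Judging from the base-class change above (`if self._return_approx_loss and l_approx is not None: return l_approx`), the `return_approx_loss` flag threaded through `RandomizedFDM`, `RDSA`, `GaussianSmoothing` and `MeZO` lets the wrapped closure report the finite-difference loss estimate instead of spending an extra closure evaluation. A hedged usage sketch; `tz.Optimizer`, `tz.m.RandomizedFDM` and `tz.m.LR` are taken from the docstrings above, the toy model is only for illustration:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

opt = tz.Optimizer(
    model.parameters(),
    # report the FD loss estimate instead of re-evaluating the closure exactly
    tz.m.RandomizedFDM(n_samples=4, return_approx_loss=True),
    tz.m.LR(1e-2),
)
```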

torchzero/modules/least_squares/gn.py
@@ -1,28 +1,31 @@
 import torch
-from ...core import Module

-from ...
+from ...core import Chainable, Transform
+from ...linalg import linear_operator
 from ...utils import vec_to_tensors
-from ...utils.
-
+from ...utils.derivatives import flatten_jacobian, jacobian_wrt
+
+
+class SumOfSquares(Transform):
     """Sets loss to be the sum of squares of values returned by the closure.

     This is meant to be used to test least squares methods against ordinary minimization methods.

     To use this, the closure should return a vector of values to minimize sum of squares of.
-    Please add the
+    Please add the ``backward`` argument, it will always be False but it is required.
     """
     def __init__(self):
         super().__init__()

     @torch.no_grad
-    def
-        closure =
+    def update_states(self, objective, states, settings):
+        closure = objective.closure

         if closure is not None:
+
             def sos_closure(backward=True):
                 if backward:
-
+                    objective.zero_grad()
                     with torch.enable_grad():
                         loss = closure(False)
                         loss = loss.pow(2).sum()
@@ -32,18 +35,19 @@ class SumOfSquares(Module):
                 loss = closure(False)
                 return loss.pow(2).sum()

-
-
-        if var.loss is not None:
-            var.loss = var.loss.pow(2).sum()
+            objective.closure = sos_closure

-        if
-
+        if objective.loss is not None:
+            objective.loss = objective.loss.pow(2).sum()

-
+        if objective.loss_approx is not None:
+            objective.loss_approx = objective.loss_approx.pow(2).sum()

+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        return objective

-class GaussNewton(
+class GaussNewton(Transform):
     """Gauss-newton method.

     To use this, the closure should return a vector of values to minimize sum of squares of.
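
`SumOfSquares` is described above as a way to run ordinary minimizers on the same residual closure that least-squares modules consume. A hedged sketch of such a comparison, reusing the Rosenbrock residuals from the `GaussNewton` docstring below; the `opt.step(closure)` call pattern is assumed from the standard PyTorch closure API:

```python
import torch
import torchzero as tz

X = torch.tensor([-1.1, 2.5], requires_grad=True)

def closure(backward=True):
    # residual vector; SumOfSquares squares and sums it into a scalar loss
    x1, x2 = X
    return torch.stack([(1 - x1), 100 * (x2 - x1**2)])

# ordinary minimizer on the squared residuals, for comparison with GaussNewton below
opt = tz.Optimizer([X], tz.m.SumOfSquares(), tz.m.LBFGS(), tz.m.Backtracking())
for _ in range(25):
    opt.step(closure)
```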

@@ -57,6 +61,9 @@ class GaussNewton(Module):

     Args:
         reg (float, optional): regularization parameter. Defaults to 1e-8.
+        update_freq (int, optional):
+            frequency of computing the jacobian. When jacobian is not computed, only residuals are computed and updated.
+            Defaults to 1.
         batched (bool, optional): whether to use vmapping. Defaults to True.

     Examples:
@@ -68,7 +75,7 @@ class GaussNewton(Module):
             return torch.stack([(1 - x1), 100 * (x2 - x1**2)])

         X = torch.tensor([-1.1, 2.5], requires_grad=True)
-        opt = tz.
+        opt = tz.Optimizer([X], tz.m.GaussNewton(), tz.m.Backtracking())

         # define the closure for line search
         def closure(backward=True):
@@ -86,7 +93,7 @@ class GaussNewton(Module):
         y = torch.randn(64, 10)

         model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.TrustCG(tz.m.GaussNewton()),
         )
@@ -101,35 +108,62 @@ class GaussNewton(Module):
         print(f'{losses.mean() = }')
     ```
     """
-    def __init__(self, reg:float = 1e-8, batched:bool=True, ):
-
+    def __init__(self, reg:float = 1e-8, update_freq: int= 1, batched:bool=True, inner: Chainable | None = None):
+        defaults=dict(update_freq=update_freq,batched=batched, reg=reg)
+        super().__init__(defaults=defaults)
+        if inner is not None: self.set_child('inner', inner)

     @torch.no_grad
-    def
-
-
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        params = objective.params
+        closure = objective.closure
+        batched = fs['batched']
+        update_freq = fs['update_freq']
+
+        # compute residuals
+        r = objective.loss
+        if r is None:
+            assert closure is not None
+            with torch.enable_grad():
+                r = objective.get_loss(backward=False) # n_residuals
+        assert isinstance(r, torch.Tensor)
+
+        # set sum of squares scalar loss and it's gradient to objective
+        objective.loss = r.pow(2).sum()
+
+        step = self.increment_counter("step", start=0)
+
+        if step % update_freq == 0:
+
+            # compute jacobian
+            with torch.enable_grad():
+                J_list = jacobian_wrt([r.ravel()], params, batched=batched)
+
+            J = self.global_state["J"] = flatten_jacobian(J_list) # (n_residuals, ndim)

-
-
+        else:
+            J = self.global_state["J"]

-
-        with torch.enable_grad():
-            f = var.get_loss(backward=False) # n_out
-            assert isinstance(f, torch.Tensor)
-            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+        Jr = J.T @ r.detach() # (ndim)

-
+        # if there are more residuals, solve (J^T J)x = J^T r, so we need Jr
+        # otherwise solve (J J^T)z = r and set x = J^T z, so we need r
+        n_residuals, ndim = J.shape
+        if n_residuals >= ndim or "inner" in self.children:
+            self.global_state["Jr"] = Jr

-
-
-
-
+        else:
+            self.global_state["r"] = r
+
+        objective.grads = vec_to_tensors(Jr, objective.params)

         # set closure to calculate sum of squares for line searches etc
-        if
+        if closure is not None:
             def sos_closure(backward=True):
+
                 if backward:
-
+                    objective.zero_grad()
                     with torch.enable_grad():
                         loss = closure(False).pow(2).sum()
                         loss.backward()
@@ -138,24 +172,61 @@ class GaussNewton(Module):
                     loss = closure(False).pow(2).sum()
                     return loss

-
+            objective.closure = sos_closure

     @torch.no_grad
-    def
-
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        reg = fs['reg']
+
+        J: torch.Tensor = self.global_state['J']
+        nresiduals, ndim = J.shape
+        if nresiduals >= ndim or "inner" in self.children:
+
+            # (J^T J)v = J^T r
+            Jr: torch.Tensor = self.global_state['Jr']
+
+            # inner step
+            if "inner" in self.children:
+
+                # var.grad is set to unflattened Jr
+                assert objective.grads is not None
+                objective = self.inner_step("inner", objective, must_exist=True)
+                Jr_list = objective.get_updates()
+                Jr = torch.cat([t.ravel() for t in Jr_list])
+
+            JtJ = J.T @ J # (ndim, ndim)
+            if reg != 0:
+                JtJ.add_(torch.eye(JtJ.size(0), device=JtJ.device, dtype=JtJ.dtype).mul_(reg))
+
+            if nresiduals >= ndim:
+                v, info = torch.linalg.solve_ex(JtJ, Jr) # pylint:disable=not-callable
+            else:
+                v = torch.linalg.lstsq(JtJ, Jr).solution # pylint:disable=not-callable
+
+            objective.updates = vec_to_tensors(v, objective.params)
+            return objective
+
+        # else:
+        # solve (J J^T)z = r and set v = J^T z
+        # we need (J^T J)v = J^T r
+        # if z is solution to (G G^T)z = r, and v = J^T z
+        # then (J^T J)v = (J^T J) (J^T z) = J^T (J J^T) z = J^T r
+        # therefore (J^T J)v = J^T r
+        # also this gives a minimum norm solution

-
-        Gtf = self.global_state['Gtf']
+        r = self.global_state['r']

-
+        JJT = J @ J.T # (nresiduals, nresiduals)
         if reg != 0:
-
+            JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))

-
+        z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
+        v = J.T @ z

-
-        return
+        objective.updates = vec_to_tensors(v, objective.params)
+        return objective

-    def get_H(self,
-
-        return linear_operator.AtA(
+    def get_H(self, objective=...):
+        J = self.global_state['J']
+        return linear_operator.AtA(J)
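
The comment block in `apply_states` relies on the identity that, in the underdetermined case, solving `(J J^T) z = r` and taking `v = J^T z` also solves the normal equations `(J^T J) v = J^T r` (and gives the minimum-norm solution). A small standalone check of that identity, illustrative only and not package code:

```python
import torch

torch.manual_seed(0)
n_residuals, ndim = 5, 20            # fewer residuals than parameters
J = torch.randn(n_residuals, ndim)
r = torch.randn(n_residuals)

z = torch.linalg.solve(J @ J.T, r)   # (J J^T) z = r
v = J.T @ z                          # candidate Gauss-Newton step

# v satisfies the normal equations (J^T J) v = J^T r
assert torch.allclose(J.T @ (J @ v), J.T @ r, atol=1e-4)
```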

torchzero/modules/line_search/backtracking.py
@@ -77,7 +77,7 @@ class Backtracking(LineSearchBase):
     Gradient descent with backtracking line search:

     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Backtracking()
     )
@@ -85,7 +85,7 @@ class Backtracking(LineSearchBase):

     L-BFGS with backtracking line search:
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LBFGS(),
         tz.m.Backtracking()
@@ -117,7 +117,7 @@ class Backtracking(LineSearchBase):

         # # directional derivative
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), var.get_updates()))

         # scale init
         init_scale = self.global_state.get('init_scale', 1)
@@ -199,7 +199,7 @@ class AdaptiveBacktracking(LineSearchBase):

         # directional derivative (0 if c = 0 because it is not needed)
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), update))

         # scale beta
         beta = beta * self.global_state['beta_scale']