torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/quasi_newton.py

@@ -5,10 +5,10 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, TensorTransform, Transform
 from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
-from ...
-from ..
+from ...linalg import linear_operator
+from ..opt_utils import initial_step_size, safe_clip
 
 
 
@@ -17,7 +17,7 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     elif state[key].shape != value.shape: state[key] = value
     else: state[key].lerp_(value, 1-beta)
 
-class HessianUpdateStrategy(TensorwiseTransform, ABC):
+class HessianUpdateStrategy(TensorTransform, ABC):
     """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.
 
     This is an abstract class, to use it, subclass it and override ``update_H`` and/or ``update_B``,
@@ -106,11 +106,12 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
         scale_first: bool = False,
         concat_params: bool = True,
         inverse: bool = True,
+        uses_loss: bool = False,
         inner: Chainable | None = None,
     ):
         if defaults is None: defaults = {}
         safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_restart=ptol_restart, gtol=gtol, inverse=inverse, beta=beta, restart_interval=restart_interval, scale_first=scale_first))
-        super().__init__(defaults,
+        super().__init__(defaults, uses_loss=uses_loss, concat_params=concat_params, update_freq=update_freq, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
@@ -141,23 +142,27 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
         return H
 
     # ------------------------------ common methods ------------------------------ #
-    def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
+    def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float | None:
         """returns multiplier to B on 2nd step if ``init_scale='auto'``. H should be divided by this!"""
         ys = y.dot(s)
         yy = y.dot(y)
-
-        return
+        tiny = torch.finfo(ys.dtype).tiny * 2
+        if ys > tiny and yy > tiny: return yy/ys
+        return None
 
-    def reset_P(self, P: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
+    def reset_P(self, P: torch.Tensor, s:torch.Tensor, y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
         """resets ``P`` which is either B or H"""
         set_storage_(P, self.initialize_P(s.numel(), device=P.device, dtype=P.dtype, is_inverse=inverse))
-        if init_scale == 'auto':
-
+        if init_scale == 'auto':
+            init_scale = self.auto_initial_scale(s,y)
+            state["scaled"] = init_scale is not None
+
+        if init_scale is not None and init_scale != 1:
             if inverse: P /= init_scale
             else: P *= init_scale
 
     @torch.no_grad
-    def
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
         p = param.view(-1); g = tensor.view(-1)
         inverse = setting['inverse']
         M_key = 'H' if inverse else 'B'
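Editorial note (not part of the diff): the new `auto_initial_scale` returns the usual scalar applied to the quasi-Newton matrix before its first update; B is multiplied by the returned value and H divided by it, as its docstring states, with None returned when the dot products are near zero:

```latex
% initial scaling used when init_scale='auto'
\gamma = \frac{y^\top y}{y^\top s}, \qquad B \leftarrow \gamma B, \qquad H \leftarrow H / \gamma
```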
@@ -182,6 +187,7 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             state['f_prev'] = loss
             state['p_prev'] = p.clone()
             state['g_prev'] = g.clone()
+            state["scaled"] = False
             return
 
         state['f'] = loss
@@ -205,9 +211,13 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
         if gtol is not None and y.abs().max() <= gtol:
             return
 
-
-
-
+        # apply automatic initial scale if it hasn't been applied
+        if (not state["scaled"]) and (init_scale == 'auto'):
+            scale = self.auto_initial_scale(s,y)
+            if scale is not None:
+                state["scaled"] = True
+                if inverse: M /= self.auto_initial_scale(s,y)
+                else: M *= self.auto_initial_scale(s,y)
 
         beta = setting['beta']
         if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
@@ -223,7 +233,7 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
         state['f_prev'] = loss
 
     @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         step = state['step']
 
         if setting['scale_first'] and step == 1:
@@ -250,8 +260,8 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             self.global_state.clear()
             return tensor.mul_(initial_step_size(tensor))
 
-    def get_H(self,
-        param =
+    def get_H(self, objective):
+        param = objective.params[0]
         state = self.state[param]
         settings = self.settings[param]
         if "B" in state:
@@ -367,22 +377,21 @@ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     B += term1.sub_(term2)
     return B
 
-
+
+def bfgs_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
     sy = s.dot(y)
     if sy <= tol: return H
 
-
-
-    Hy = H@y
-    scale1 = (sy + y.dot(Hy)) / sy_sq
-    term1 = s.outer(s).mul_(scale1)
+    rho = 1.0 / sy
+    Hy = H @ y
 
-
-    term2 =
+    term1 = (s.outer(s)).mul_(rho * (1 + rho * y.dot(Hy)))
+    term2 = (Hy.outer(s) + s.outer(Hy)).mul_(rho)
 
-    H
+    H.add_(term1).sub_(term2)
     return H
 
+
 class BFGS(_InverseHessianUpdateStrategyDefaults):
     """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
 
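Editorial note (not part of the diff): with rho = 1/(sᵀy), the rewritten `bfgs_H_` matches the standard BFGS update of the inverse-Hessian approximation:

```latex
% BFGS inverse-Hessian update implemented by bfgs_H_
H \leftarrow H + \rho\left(1 + \rho\, y^\top H y\right) s s^\top
  - \rho\left(H y\, s^\top + s\, y^\top H\right),
\qquad \rho = \frac{1}{s^\top y}
```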
@@ -428,7 +437,7 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):
     BFGS with backtracking line search:
 
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.BFGS(),
         tz.m.Backtracking()
@@ -437,7 +446,7 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):
 
     BFGS with trust region
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
     )
@@ -505,7 +514,7 @@ class SR1(_InverseHessianUpdateStrategyDefaults):
 
     SR1 with trust region
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
     )
@@ -1005,7 +1014,7 @@ def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
     return g - (y * (s.dot(g) / sy))
 
 
-class GradientCorrection(Transform):
+class GradientCorrection(TensorTransform):
     """
     Estimates gradient at minima along search direction assuming function is quadratic.
 
@@ -1015,7 +1024,7 @@ class GradientCorrection(Transform):
     L-BFGS with gradient correction
 
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LBFGS(inner=tz.m.GradientCorrection()),
         tz.m.Backtracking()
@@ -1027,9 +1036,9 @@ class GradientCorrection(Transform):
 
     """
    def __init__(self):
-        super().__init__(
+        super().__init__()
 
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        if 'p_prev' not in states[0]:
            p_prev = unpack_states(states, tensors, 'p_prev', init=params)
            g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
@@ -1154,6 +1163,7 @@ class NewSSM(HessianUpdateStrategy):
             scale_first=scale_first,
             concat_params=concat_params,
             inverse=True,
+            uses_loss=True,
             inner=inner,
         )
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
@@ -1171,13 +1181,18 @@ class NewSSM(HessianUpdateStrategy):
 
 # this is supposed to be equivalent (and it is)
 def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
-
-
-
-    term =
-    H.sub_(term, alpha=1-alpha**2)
+    Hy = H @ y
+    yHy = safe_clip(y.dot(Hy))
+    term = Hy.outer(Hy).div_(yHy)
+    H.sub_(term, alpha=(1-alpha**2))
     return H
 
+# def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
+#     Hy = H @ y
+#     yHy = safe_clip(y.dot(Hy))
+#     H -= (Hy.outer(y) @ H).div_(yHy)
+#     return H
+
 class ShorR(HessianUpdateStrategy):
     """Shor’s r-algorithm.
 
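Editorial note (not part of the diff): the rewritten `shor_r_` applies the rank-one deflation of H along Hy that defines Shor's r-algorithm, with space-dilation parameter alpha:

```latex
% Shor r-algorithm update implemented by shor_r_
H \leftarrow H - \left(1 - \alpha^{2}\right) \frac{(H y)(H y)^\top}{y^\top H y}
```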
torchzero/modules/quasi_newton/sg2.py

@@ -1,29 +1,39 @@
 import torch
 
-from ...core import
-from ...utils import TensorList,
-from
+from ...core import Chainable, Transform
+from ...utils import TensorList, unpack_dicts, unpack_states, vec_to_tensors_
+from ...linalg.linear_operator import Dense
+
 
 def sg2_(
     delta_g: torch.Tensor,
     cd: torch.Tensor,
 ) -> torch.Tensor:
-    """cd is c * perturbation
-    (or divide delta_g by two)."""
+    """cd is c * perturbation."""
 
-    M = torch.outer(
+    M = torch.outer(0.5 / cd, delta_g)
     H_hat = 0.5 * (M + M.T)
 
     return H_hat
 
 
 
-class SG2(Module):
+class SG2(Transform):
     """second-order stochastic gradient
 
+    2SPSA (second-order SPSA)
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.SPSA(),
+        tz.m.SG2(),
+        tz.m.LR(1e-2),
+    )
+    ```
+
     SG2 with line search
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SG2(),
         tz.m.Backtracking()
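Editorial note (not part of the diff): `sg2_` forms the symmetrized simultaneous-perturbation Hessian estimate used by SG2/2SPSA. With `cd` = cΔ and δg = g(x + cΔ) − g(x − cΔ), the two-sided gradient difference computed in `update_states`:

```latex
% per-sample Hessian estimate returned by sg2_
M_{ij} = \frac{\delta g_j}{2\,c\,\Delta_i}, \qquad \hat H = \tfrac{1}{2}\left(M + M^\top\right)
```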
@@ -32,9 +42,9 @@ class SG2(Module):
 
     SG2 with trust region
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
-        tz.m.LevenbergMarquardt(tz.m.SG2()),
+        tz.m.LevenbergMarquardt(tz.m.SG2(beta=0.75. n_samples=4)),
     )
     ```
 
@@ -43,61 +53,51 @@ class SG2(Module):
     def __init__(
         self,
         n_samples: int = 1,
-
+        n_first_step_samples: int = 10,
+        start_step: int = 10,
         beta: float | None = None,
-        damping: float =
-
-        one_sided: bool = False, # one-sided hessian
-        use_lstsq: bool = True,
+        damping: float = 1e-4,
+        h: float = 1e-2,
         seed=None,
+        update_freq: int = 1,
         inner: Chainable | None = None,
     ):
-        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping,
-        super().__init__(defaults)
-
-        if inner is not None: self.set_child('inner', inner)
+        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, seed=seed, start_step=start_step, n_first_step_samples=n_first_step_samples)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def
-
-        self.
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        k = self.increment_counter("step", 0)
 
-        params = TensorList(
-        closure =
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None:
             raise RuntimeError("closure is required for SG2")
         generator = self.get_generator(params[0].device, self.defaults["seed"])
 
-        h =
+        h = unpack_dicts(settings, "h")
         x_0 = params.clone()
-        n_samples =
+        n_samples = fs["n_samples"]
+        if k == 0: n_samples = fs["n_first_step_samples"]
         H_hat = None
 
+        # compute new approximation
         for i in range(n_samples):
             # generate perturbation
             cd = params.rademacher_like(generator=generator).mul_(h)
 
-            #
-
-
-
-            closure()
+            # two sided hessian approximation
+            params.add_(cd)
+            closure()
+            g_p = params.grad.fill_none_(params)
 
-
-
+            params.copy_(x_0)
+            params.sub_(cd)
+            closure()
+            g_n = params.grad.fill_none_(params)
 
-
-            else:
-                params.add_(cd)
-                closure()
-                g_p = params.grad.fill_none_(params)
-
-                params.copy_(x_0)
-                params.sub_(cd)
-                closure()
-                g_n = params.grad.fill_none_(params)
-
-                delta_g = g_p - g_n
+            delta_g = g_p - g_n
 
             # restore params
             params.set_(x_0)
@@ -114,179 +114,43 @@ class SG2(Module):
         assert H_hat is not None
         if n_samples > 1: H_hat /= n_samples
 
+        # add damping
+        if fs["damping"] != 0:
+            reg = torch.eye(H_hat.size(0), device=H_hat.device, dtype=H_hat.dtype).mul_(fs["damping"])
+            H_hat += reg
+
         # update H
         H = self.global_state.get("H", None)
         if H is None: H = H_hat
         else:
-            beta =
-            if beta is None: beta = k / (k+
+            beta = fs["beta"]
+            if beta is None: beta = (k+1) / (k+2)
             H.lerp_(H_hat, 1-beta)
 
         self.global_state["H"] = H
 
 
     @torch.no_grad
-    def
-
-
-            H = self.global_state["H"],
-            damping = self.defaults["damping"],
-            inner = self.children.get("inner", None),
-            H_tfm=None,
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj=None,
-        )
-
-        var.update = vec_to_tensors(dir, var.params)
-        return var
-
-    def get_H(self,var=...):
-        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
-
-
-
-
-# two sided
-# we have g via x + d, x - d
-# H via g(x + d), g(x - d)
-# 1 is x, x+2d
-# 2 is x, x-2d
-# 5 evals in total
-
-# one sided
-# g via x, x + d
-# 1 is x, x + d
-# 2 is x, x - d
-# 3 evals and can use two sided for g_0
-
-class SPSA2(Module):
-    """second-order SPSA
-
-    SPSA2 with line search
-    ```python
-    opt = tz.Modular(
-        model.parameters(),
-        tz.m.SPSA2(),
-        tz.m.Backtracking()
-    )
-    ```
-
-    SPSA2 with trust region
-    ```python
-    opt = tz.Modular(
-        model.parameters(),
-        tz.m.LevenbergMarquardt(tz.m.SPSA2()),
-    )
-    ```
-    """
-
-    def __init__(
-        self,
-        n_samples: int = 1,
-        h: float = 1e-2,
-        beta: float | None = None,
-        damping: float = 0,
-        eigval_fn=None,
-        use_lstsq: bool = True,
-        seed=None,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
-        super().__init__(defaults)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def update(self, var):
-        k = self.global_state.get('step', 0) + 1
-        self.global_state["step"] = k
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        updates = objective.get_updates()
 
-
-
-        if
-
+        H: torch.Tensor = self.global_state["H"]
+        k = self.global_state["step"]
+        if k < fs["start_step"]:
+            # don't precondition yet
+            # I guess we can try using trace to scale the update
+            # because it will have horrible scaling otherwise
+            torch._foreach_div_(updates, H.trace())
+            return objective
 
-
+        b = torch.cat([t.ravel() for t in updates])
+        sol = torch.linalg.lstsq(H, b).solution # pylint:disable=not-callable
 
-
-
-        n_samples = self.defaults["n_samples"]
-        H_hat = None
-        g_0 = None
-
-        for i in range(n_samples):
-            # perturbations for g and H
-            cd_g = params.rademacher_like(generator=generator).mul_(h)
-            cd_H = params.rademacher_like(generator=generator).mul_(h)
-
-            # evaluate 4 points
-            x_p = x_0 + cd_g
-            x_n = x_0 - cd_g
+        vec_to_tensors_(sol, updates)
+        return objective
 
-
-
-            params.add_(cd_H)
-            f_pp = closure(False)
+    def get_H(self, objective=...):
+        return Dense(self.global_state["H"])
 
-            params.set_(x_n)
-            f_n = closure(False)
-            params.add_(cd_H)
-            f_np = closure(False)
-
-            g_p_vec = (f_pp - f_p) / cd_H
-            g_n_vec = (f_np - f_n) / cd_H
-            delta_g = g_p_vec - g_n_vec
-
-            # restore params
-            params.set_(x_0)
 
-            # compute grad
-            g_i = (f_p - f_n) / (2 * cd_g)
-            if g_0 is None: g_0 = g_i
-            else: g_0 += g_i
-
-            # compute H hat
-            H_i = sg2_(
-                delta_g = delta_g.to_vec().div_(2.0),
-                cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
-            )
-            if H_hat is None: H_hat = H_i
-            else: H_hat += H_i
-
-        assert g_0 is not None and H_hat is not None
-        if n_samples > 1:
-            g_0 /= n_samples
-            H_hat /= n_samples
-
-        # set grad to approximated grad
-        var.grad = g_0
-
-        # update H
-        H = self.global_state.get("H", None)
-        if H is None: H = H_hat
-        else:
-            beta = self.defaults["beta"]
-            if beta is None: beta = k / (k+1)
-            H.lerp_(H_hat, 1-beta)
-
-        self.global_state["H"] = H
-
-    @torch.no_grad
-    def apply(self, var):
-        dir = _newton_step(
-            var=var,
-            H = self.global_state["H"],
-            damping = self.defaults["damping"],
-            inner = self.children.get("inner", None),
-            H_tfm=None,
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj=None,
-        )
-
-        var.update = vec_to_tensors(dir, var.params)
-        return var
-
-    def get_H(self,var=...):
-        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])