torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
--- torchzero/modules/step_size/lr.py (0.3.15)
+++ torchzero/modules/step_size/lr.py (0.4.1)
@@ -2,7 +2,7 @@
 import torch
 import random

-from ...core import
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, generic_ne, unpack_dicts

 def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
@@ -12,24 +12,24 @@ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
         return tensors * lr
     return tensors

-class LR(
+class LR(TensorTransform):
     """Learning rate. Adding this module also adds support for LR schedulers."""
     def __init__(self, lr: float):
         defaults=dict(lr=lr)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)

-class StepSize(
+class StepSize(TensorTransform):
     """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
     def __init__(self, step_size: float, key = 'step_size'):
         defaults={"key": key, key: step_size}
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)


@@ -38,8 +38,8 @@ def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberLi
     if step > steps: return end_lr
     return start_lr + (end_lr - start_lr) * (step / steps)

-class Warmup(
-    """Learning rate warmup, linearly increases learning rate multiplier from
+class Warmup(TensorTransform):
+    """Learning rate warmup, linearly increases learning rate multiplier from ``start_lr`` to ``end_lr`` over ``steps`` steps.

     Args:
         steps (int, optional): number of steps to perform warmup for. Defaults to 100.
@@ -51,7 +51,7 @@ class Warmup(Transform):

     .. code-block:: python

-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.Adam(),
             tz.m.LR(1e-2),
@@ -64,7 +64,7 @@ class Warmup(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
         num_steps = settings[0]['steps']
         step = self.global_state.get('step', 0)
@@ -77,7 +77,7 @@ class Warmup(Transform):
         self.global_state['step'] = step + 1
         return tensors

-class WarmupNormClip(
+class WarmupNormClip(TensorTransform):
     """Warmup via clipping of the update norm.

     Args:
@@ -90,7 +90,7 @@ class WarmupNormClip(Transform):

     .. code-block:: python

-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.Adam(),
             tz.m.WarmupNormClip(steps=1000)
@@ -102,7 +102,7 @@ class WarmupNormClip(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
         num_steps = settings[0]['steps']
         step = self.global_state.get('step', 0)
@@ -118,8 +118,8 @@ class WarmupNormClip(Transform):
         return tensors


-class RandomStepSize(
-    """Uses random global or layer-wise step size from
+class RandomStepSize(TensorTransform):
+    """Uses random global or layer-wise step size from ``low`` to ``high``.

     Args:
         low (float, optional): minimum learning rate. Defaults to 0.
@@ -133,7 +133,7 @@ class RandomStepSize(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         s = settings[0]
         parameterwise = s['parameterwise']
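Note: the lr.py hunks above rebase `LR`, `StepSize`, `Warmup`, `WarmupNormClip`, and `RandomStepSize` onto `TensorTransform`, with the per-step hook renamed to `multi_tensor_apply(self, tensors, params, grads, loss, states, settings)`. The sketch below is illustrative only: a hypothetical custom module written against that hook, assuming nothing beyond the constructor and hook signature visible in this diff.

```python
import torch
from torchzero.core import TensorTransform


class HalvedLR(TensorTransform):
    """Hypothetical example module: applies half of the configured learning rate."""

    def __init__(self, lr: float):
        defaults = dict(lr=lr)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # `settings` holds one dict per parameter, mirroring how LR reads s['lr'] above.
        torch._foreach_mul_(tensors, [0.5 * s['lr'] for s in settings])
        return tensors
```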
--- torchzero/modules/termination/termination.py (0.3.15)
+++ torchzero/modules/termination/termination.py (0.4.1)
@@ -1,11 +1,11 @@
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import cast
+from typing import cast, final

 import torch

-from ...core import Module,
+from ...core import Module, Objective
 from ...utils import Metrics, TensorList, safe_dict_update_, tofloat


@@ -16,14 +16,15 @@ class TerminationCriteriaBase(Module):
         super().__init__(defaults)

     @abstractmethod
-    def termination_criteria(self,
+    def termination_criteria(self, objective: Objective) -> bool:
         ...

-
+    @final
+    def should_terminate(self, objective: Objective) -> bool:
         n_bad = self.global_state.get('_n_bad', 0)
         n = self.defaults['_n']

-        if self.termination_criteria(
+        if self.termination_criteria(objective):
             n_bad += 1
             if n_bad >= n:
                 self.global_state['_n_bad'] = 0
@@ -36,12 +37,12 @@ class TerminationCriteriaBase(Module):
         return False


-    def update(self,
-
-        if
+    def update(self, objective):
+        objective.should_terminate = self.should_terminate(objective)
+        if objective.should_terminate: self.global_state['_n_bad'] = 0

-    def apply(self,
-        return
+    def apply(self, objective):
+        return objective


 class TerminateAfterNSteps(TerminationCriteriaBase):
@@ -49,7 +50,7 @@ class TerminateAfterNSteps(TerminationCriteriaBase):
         defaults = dict(steps=steps)
         super().__init__(defaults)

-    def termination_criteria(self,
+    def termination_criteria(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -61,16 +62,17 @@ class TerminateAfterNEvaluations(TerminationCriteriaBase):
         defaults = dict(maxevals=maxevals)
         super().__init__(defaults)

-    def termination_criteria(self,
+    def termination_criteria(self, objective):
         maxevals = self.defaults['maxevals']
-
+        assert objective.modular is not None
+        return objective.modular.num_evaluations >= maxevals

 class TerminateAfterNSeconds(TerminationCriteriaBase):
     def __init__(self, seconds:float, sec_fn = time.time):
         defaults = dict(seconds=seconds, sec_fn=sec_fn)
         super().__init__(defaults)

-    def termination_criteria(self,
+    def termination_criteria(self, objective):
         max_seconds = self.defaults['seconds']
         sec_fn = self.defaults['sec_fn']

@@ -88,10 +90,10 @@ class TerminateByGradientNorm(TerminationCriteriaBase):
         defaults = dict(tol=tol, ord=ord)
         super().__init__(defaults, n=n)

-    def termination_criteria(self,
+    def termination_criteria(self, objective):
         tol = self.defaults['tol']
         ord = self.defaults['ord']
-        return TensorList(
+        return TensorList(objective.get_grads()).global_metric(ord) <= tol


 class TerminateByUpdateNorm(TerminationCriteriaBase):
@@ -100,20 +102,20 @@ class TerminateByUpdateNorm(TerminationCriteriaBase):
         defaults = dict(tol=tol, ord=ord)
         super().__init__(defaults, n=n)

-    def termination_criteria(self,
+    def termination_criteria(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

         tol = self.defaults['tol']
         ord = self.defaults['ord']

-        p_prev = self.get_state(
+        p_prev = self.get_state(objective.params, 'p_prev', cls=TensorList)
         if step == 0:
-            p_prev.copy_(
+            p_prev.copy_(objective.params)
             return False

-        should_terminate = (p_prev -
-        p_prev.copy_(
+        should_terminate = (p_prev - objective.params).global_metric(ord) <= tol
+        p_prev.copy_(objective.params)
         return should_terminate


@@ -122,10 +124,10 @@ class TerminateOnNoImprovement(TerminationCriteriaBase):
         defaults = dict(tol=tol)
         super().__init__(defaults, n=n)

-    def termination_criteria(self,
+    def termination_criteria(self, objective):
         tol = self.defaults['tol']

-        f = tofloat(
+        f = tofloat(objective.get_loss(False))
         if 'f_min' not in self.global_state:
             self.global_state['f_min'] = f
             return False
@@ -141,9 +143,9 @@ class TerminateOnLossReached(TerminationCriteriaBase):
         defaults = dict(value=value)
         super().__init__(defaults)

-    def termination_criteria(self,
+    def termination_criteria(self, objective):
         value = self.defaults['value']
-        return
+        return objective.get_loss(False) <= value

 class TerminateAny(TerminationCriteriaBase):
     def __init__(self, *criteria: TerminationCriteriaBase):
@@ -151,9 +153,9 @@ class TerminateAny(TerminationCriteriaBase):

         self.set_children_sequence(criteria)

-    def termination_criteria(self,
+    def termination_criteria(self, objective: Objective) -> bool:
         for c in self.get_children_sequence():
-            if cast(TerminationCriteriaBase, c).termination_criteria(
+            if cast(TerminationCriteriaBase, c).termination_criteria(objective): return True

         return False

@@ -163,9 +165,9 @@ class TerminateAll(TerminationCriteriaBase):

         self.set_children_sequence(criteria)

-    def termination_criteria(self,
+    def termination_criteria(self, objective: Objective) -> bool:
         for c in self.get_children_sequence():
-            if not cast(TerminationCriteriaBase, c).termination_criteria(
+            if not cast(TerminationCriteriaBase, c).termination_criteria(objective): return False

         return True

@@ -173,7 +175,7 @@ class TerminateNever(TerminationCriteriaBase):
     def __init__(self):
         super().__init__()

-    def termination_criteria(self,
+    def termination_criteria(self, objective): return False

 def make_termination_criteria(
     ftol: float | None = None,
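Note: after this rework a concrete criterion only implements `termination_criteria(objective)`; the base class counts consecutive positive checks in `should_terminate` and writes the result into `objective.should_terminate` during `update`. Below is an illustrative sketch of a custom criterion in that style, modeled on `TerminateOnNoImprovement` above; the class name and the import path are assumptions based solely on this diff.

```python
from torchzero.core import Objective
from torchzero.modules.termination.termination import TerminationCriteriaBase


class TerminateOnSmallLoss(TerminationCriteriaBase):
    """Hypothetical criterion: stop once the loss stays below `tol` on `n` consecutive checks."""

    def __init__(self, tol: float = 1e-8, n: int = 2):
        defaults = dict(tol=tol)
        super().__init__(defaults, n=n)

    def termination_criteria(self, objective: Objective) -> bool:
        # get_loss(False) evaluates the closure without a backward pass,
        # the same call TerminateOnNoImprovement uses above.
        return objective.get_loss(False) <= self.defaults['tol']
```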
--- torchzero/modules/trust_region/cubic_regularization.py (0.3.15)
+++ torchzero/modules/trust_region/cubic_regularization.py (0.4.1)
@@ -5,7 +5,7 @@ import torch

 from ...core import Chainable, Module
 from ...utils import TensorList, vec_to_tensors
-from ...
+from ...linalg.linear_operator import LinearOperator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy


@@ -58,7 +58,7 @@ def ls_cubic_solver(f, g:torch.Tensor, H:LinearOperator, M: float, loss_at_param
     for _ in range(it_max):
         r_try = (r_min + r_max) / 2
         lam = r_try * M
-        s_lam = H.
+        s_lam = H.solve_plus_diag(g, lam).neg()
         # s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
         solver_it += 1
         crit = conv_criterion(s_lam, r_try)
@@ -109,7 +109,7 @@ class CubicRegularization(TrustRegionBase):

     .. code-block:: python

-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.CubicRegularization(tz.m.Newton()),
         )
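Note: `ls_cubic_solver` above bisects on a trial radius `r_try`, solving a shifted system with `lam = r_try * M` at each iteration (now through `H.solve_plus_diag`). The sketch below is a simplified illustration of that bisection on a dense Hessian, with plain `torch.linalg` calls standing in for the package's `LinearOperator` and a cruder convergence test than the real `conv_criterion`.

```python
import torch


def cubic_step(H: torch.Tensor, g: torch.Tensor, M: float,
               r_max: float = 1e8, it_max: int = 100) -> torch.Tensor:
    """Approximate cubic-model step: find r with ||s(r)|| ~= r, where s(r) solves (H + M*r*I) s = -g."""
    eye = torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    r_min = 0.0
    s = torch.zeros_like(g)
    for _ in range(it_max):
        r_try = (r_min + r_max) / 2
        lam = r_try * M
        s = torch.linalg.solve(H + lam * eye, g).neg()
        norm = torch.linalg.vector_norm(s).item()
        if abs(norm - r_try) <= 1e-8 * max(1.0, r_try):
            break
        # step longer than the trial radius -> more damping (larger r) is needed
        if norm > r_try: r_min = r_try
        else: r_max = r_try
    return s
```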
--- torchzero/modules/trust_region/levenberg_marquardt.py (0.3.15)
+++ torchzero/modules/trust_region/levenberg_marquardt.py (0.4.1)
@@ -2,7 +2,7 @@
 import torch

 from ...core import Chainable, Module
-from ...
+from ...linalg import linear_operator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy


@@ -32,38 +32,31 @@ class LevenbergMarquardt(TrustRegionBase):
         max_attempts (max_attempts, optional):
             maximum number of trust region size size reductions per step. A zero update vector is returned when
             this limit is exceeded. Defaults to 10.
+        adaptive (bool, optional):
+            if True, trust radius is multiplied by square root of gradient norm.
         fallback (bool, optional):
             if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
             be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
         inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.

-    Examples:
-        Gauss-Newton with Levenberg-Marquardt trust-region
+    ### Examples:

-
+    Gauss-Newton with Levenberg-Marquardt trust-region

-
-
-
-
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
+    )
+    ```

-
-
-
-
-
-
-
-        )
-
-        First order trust region (hessian is assumed to be identity)
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LevenbergMarquardt(tz.m.Identity()),
-            )
+    LM-SR1
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+    )
+    ```

     """
     def __init__(
@@ -78,11 +71,12 @@ class LevenbergMarquardt(TrustRegionBase):
         max_attempts: int = 10,
         radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
         y: float = 0,
+        adaptive: bool = False,
         fallback: bool = False,
         update_freq: int = 1,
         inner: Chainable | None = None,
     ):
-        defaults = dict(y=y, fallback=fallback)
+        defaults = dict(y=y, fallback=fallback, adaptive=adaptive)
         super().__init__(
             defaults=defaults,
             hess_module=hess_module,
@@ -103,6 +97,7 @@ class LevenbergMarquardt(TrustRegionBase):

     def trust_solve(self, f, g, H, radius, params, closure, settings):
         y = settings['y']
+        adaptive = settings["adaptive"]

         if isinstance(H, linear_operator.DenseInverse):
             if settings['fallback']:
@@ -117,12 +112,14 @@ class LevenbergMarquardt(TrustRegionBase):
             )

         reg = 1/radius
+        if adaptive: reg = reg * torch.linalg.vector_norm(g).sqrt()
+
         if y == 0:
-            return H.
+            return H.solve_plus_diag(g, reg) # pyright:ignore[reportAttributeAccessIssue]

         diag = H.diagonal()
         diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
         if y != 1: diag = (diag*y) + (1-y)
-        return H.
+        return H.solve_plus_diag(g, diag*reg)
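Note: the new `adaptive` flag only changes the damping strength: instead of `reg = 1/radius`, the diff uses `reg = sqrt(||g||)/radius` before calling `H.solve_plus_diag(g, reg)`. Below is a standalone sketch of that computation on a dense matrix; `torch.linalg.solve` stands in for `solve_plus_diag`, and the sign handling and radius updates performed by `TrustRegionBase` are omitted.

```python
import torch


def lm_solve(H: torch.Tensor, g: torch.Tensor, radius: float, adaptive: bool = False) -> torch.Tensor:
    """Solve the damped system (H + reg*I) d = g, mirroring the y == 0 branch above."""
    reg = 1.0 / radius
    if adaptive:
        # adaptive=True: damping grows with the square root of the gradient norm
        reg = reg * torch.linalg.vector_norm(g).sqrt()
    eye = torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    return torch.linalg.solve(H + reg * eye, g)
```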
--- torchzero/modules/trust_region/trust_cg.py (0.3.15)
+++ torchzero/modules/trust_region/trust_cg.py (0.4.1)
@@ -1,7 +1,7 @@
 import torch

 from ...core import Chainable, Module
-from ...
+from ...linalg import cg, linear_operator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy


@@ -47,7 +47,7 @@ class TrustCG(TrustRegionBase):

     .. code-block:: python

-        opt = tz.
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
         )
--- torchzero/modules/trust_region/trust_region.py (0.3.15)
+++ torchzero/modules/trust_region/trust_region.py (0.4.1)
@@ -7,9 +7,16 @@ from typing import Any, Literal, Protocol, cast, final, overload

 import torch

-from ...core import Chainable, Module,
-from ...
-from ...utils
+from ...core import Chainable, Module, Objective
+from ...linalg.linear_operator import LinearOperator
+from ...utils import (
+    TensorList,
+    generic_finfo,
+    generic_vector_norm,
+    safe_dict_update_,
+    tofloat,
+    vec_to_tensors,
+)


 def _flatten_tensors(tensors: list[torch.Tensor]):
@@ -256,24 +263,24 @@ class TrustRegionBase(Module, ABC):
         """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
         ... # pylint:disable=unnecessary-ellipsis

-    def trust_region_update(self,
+    def trust_region_update(self, objective: Objective, H: LinearOperator | None) -> None:
         """updates the state of this module after H or B have been updated, if necessary"""

-    def trust_region_apply(self,
-        """Solves the trust region subproblem and outputs ``
+    def trust_region_apply(self, objective: Objective, tensors:list[torch.Tensor], H: LinearOperator | None) -> Objective:
+        """Solves the trust region subproblem and outputs ``Objective`` with the solution direction."""
         assert H is not None

-        params = TensorList(
+        params = TensorList(objective.params)
         settings = self.settings[params[0]]
         g = _flatten_tensors(tensors)

         max_attempts = settings['max_attempts']

         # loss at x_0
-        loss =
-        closure =
+        loss = objective.loss
+        closure = objective.closure
         if closure is None: raise RuntimeError("Trust region requires closure")
-        if loss is None: loss =
+        if loss is None: loss = objective.get_loss(False)
         loss = tofloat(loss)

         # trust region step and update
@@ -313,38 +320,36 @@ class TrustRegionBase(Module, ABC):
         )

         assert d is not None
-        if success:
-        else:
+        if success: objective.updates = vec_to_tensors(d, params)
+        else: objective.updates = params.zeros_like()

-        return
+        return objective


     @final
     @torch.no_grad
-    def update(self,
+    def update(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

         if step % self.defaults["update_freq"] == 0:

             hessian_module = self.children['hess_module']
-            hessian_module.update(
-            H = hessian_module.get_H(
+            hessian_module.update(objective)
+            H = hessian_module.get_H(objective)
             self.global_state["H"] = H

-            self.trust_region_update(
+            self.trust_region_update(objective, H=H)


     @final
     @torch.no_grad
-    def apply(self,
+    def apply(self, objective):
         H = self.global_state.get('H', None)

         # -------------------------------- inner step -------------------------------- #
-
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
+        objective = self.inner_step("inner", objective, must_exist=False)

         # ----------------------------------- apply ---------------------------------- #
-        return self.trust_region_apply(
+        return self.trust_region_apply(objective=objective, tensors=objective.get_updates(), H=H)
--- torchzero/modules/variance_reduction/svrg.py (0.3.15)
+++ torchzero/modules/variance_reduction/svrg.py (0.4.1)
@@ -3,15 +3,16 @@ from functools import partial

 import torch

-from ...core
+from ...core import Module, Objective
 from ...utils import tofloat


-def _reset_except_self(
-    for m in
+def _reset_except_self(objective: Objective, modules, self: Module):
+    for m in modules:
         if m is not self:
             m.reset()

+
 class SVRG(Module):
     """Stochastic variance reduced gradient method (SVRG).

@@ -43,7 +44,7 @@ class SVRG(Module):
     ## Examples:
     SVRG-LBFGS
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SVRG(len(dataloader)),
         tz.m.LBFGS(),
@@ -53,7 +54,7 @@ class SVRG(Module):

     For extra variance reduction one can use Online versions of algorithms, although it won't always help.
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SVRG(len(dataloader)),
         tz.m.Online(tz.m.LBFGS()),
@@ -62,7 +63,7 @@ class SVRG(Module):

     Variance reduction can also be applied to gradient estimators.
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SPSA(),
         tz.m.SVRG(100),
@@ -71,7 +72,7 @@ class SVRG(Module):
     ```
     ## Notes

-    The SVRG gradient is computed as ``g_b(x) - alpha * g_b(x_0) - g_f(
+    The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
     - ``x`` is current parameters
     - ``x_0`` is initial parameters, where full gradient was computed
     - ``g_b`` refers to mini-batch gradient at ``x`` or ``x_0``
@@ -83,17 +84,18 @@ class SVRG(Module):
         defaults = dict(svrg_steps = svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
         super().__init__(defaults)

+
     @torch.no_grad
-    def
-        params =
-        closure =
+    def update(self, objective):
+        params = objective.params
+        closure = objective.closure
         assert closure is not None

         if "full_grad" not in self.global_state:

             # -------------------------- calculate full gradient ------------------------- #
-            if "full_closure" in
-                full_closure =
+            if "full_closure" in objective.storage:
+                full_closure = objective.storage['full_closure']
                 with torch.enable_grad():
                     full_loss = full_closure()
                     if all(p.grad is None for p in params):
@@ -116,12 +118,12 @@ class SVRG(Module):

             # accumulate grads
             accumulator = self.get_state(params, 'accumulator')
-            grad =
+            grad = objective.get_grads()
             torch._foreach_add_(accumulator, grad)

             # accumulate loss
             loss_accumulator = self.global_state.get('loss_accumulator', 0)
-            loss_accumulator += tofloat(
+            loss_accumulator += tofloat(objective.loss)
             self.global_state['loss_accumulator'] = loss_accumulator

             # on nth step, use the accumulated gradient
@@ -136,10 +138,10 @@ class SVRG(Module):

             # otherwise skip update until enough grads are accumulated
             else:
-
-
-
-                return
+                objective.updates = None
+                objective.stop = True
+                objective.skip_update = True
+                return


         svrg_steps = self.defaults['svrg_steps']
@@ -194,7 +196,7 @@ class SVRG(Module):

             return closure(False)

-
+        objective.closure = svrg_closure

         # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
         if current_svrg_step >= svrg_steps:
@@ -203,6 +205,6 @@ class SVRG(Module):
             del self.global_state['full_loss']
             del self.global_state['x_0']
             if self.defaults['reset_before_accum']:
-
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))

-
+    def apply(self, objective): return objective

--- torchzero/modules/weight_decay/__init__.py (0.3.15)
+++ torchzero/modules/weight_decay/__init__.py (0.4.1)
@@ -1 +1,2 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .reinit import RandomReinitialize