torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
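The rename entries above move the linear-algebra helpers out of `torchzero/utils/linalg` into a new top-level `torchzero/linalg` package. A minimal sketch of what that import migration looks like for downstream code, assuming the public module paths follow the file moves (only names that actually appear in the hunks below are shown):

```python
# torchzero 0.3.14 (removed): helpers lived under torchzero.utils.linalg
#   from torchzero.utils.linalg.linear_operator import LinearOperator

# torchzero 0.4.0: linalg is a top-level subpackage
# (see the renamed linear_operator.py, qr.py and solve.py entries above)
from torchzero.linalg.linear_operator import LinearOperator, ScaledIdentity
from torchzero.linalg import linear_operator
```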
torchzero/modules/smoothing/laplacian.py

@@ -4,7 +4,7 @@ from collections.abc import Iterable
 import torch
 
 from ...utils.tensorlist import TensorList
-from ...core import …
+from ...core import TensorTransform
 
 
 def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:

@@ -55,7 +55,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
     v[-1] = 1
     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
 
-class LaplacianSmoothing(Transform):
+class LaplacianSmoothing(TensorTransform):
     """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.
 
     Args:

@@ -70,29 +70,30 @@ class LaplacianSmoothing(Transform):
             what to set on var.
 
     Examples:
-…
+        Laplacian Smoothing Gradient Descent optimizer as in the paper
 
-…
+        ```python
 
-…
-…
-…
-…
-…
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LaplacianSmoothing(),
+            tz.m.LR(1e-2),
+        )
+        ```
 
     Reference:
        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.
 
    """
-    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4…
+    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4):
         defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
-        super().__init__(defaults…
+        super().__init__(defaults)
         # precomputed denominator for when layerwise=False
         self.global_state['full_denominator'] = None
 
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         layerwise = settings[0]['layerwise']
 
         # layerwise laplacian smoothing
torchzero/modules/smoothing/sampling.py

@@ -7,14 +7,15 @@ from typing import Literal, cast
 
 import torch
 
-from ...core import Chainable, Modular, Module, …
+from ...core import Chainable, Modular, Module, Objective
 from ...core.reformulation import Reformulation
 from ...utils import Distributions, NumberList, TensorList
 from ..termination import TerminationCriteriaBase, make_termination_criteria
 
 
-def _reset_except_self(…
-…
+def _reset_except_self(objective: Objective, modules, self: Module):
+    assert objective.modular is not None
+    for m in objective.modular.flat_modules:
         if m is not self:
             m.reset()
 

@@ -98,15 +99,15 @@ class GradientSampling(Reformulation):
         self.set_child('termination', make_termination_criteria(extra=termination))
 
     @torch.no_grad
-    def pre_step(self, …
-        params = TensorList(…
+    def pre_step(self, objective):
+        params = TensorList(objective.params)
 
         fixed = self.defaults['fixed']
 
         # check termination criteria
         if 'termination' in self.children:
             termination = cast(TerminationCriteriaBase, self.children['termination'])
-            if termination.should_terminate(…
+            if termination.should_terminate(objective):
 
                 # decay sigmas
                 states = [self.state[p] for p in params]

@@ -118,7 +119,7 @@ class GradientSampling(Reformulation):
 
                 # reset on sigmas decay
                 if self.defaults['reset_on_termination']:
-…
+                    objective.post_step_hooks.append(partial(_reset_except_self, self=self))
 
                 # clear perturbations
                 self.global_state.pop('perts', None)

@@ -136,7 +137,7 @@ class GradientSampling(Reformulation):
         self.global_state['perts'] = perts
 
     @torch.no_grad
-    def closure(self, backward, closure, params, …
+    def closure(self, backward, closure, params, objective):
         params = TensorList(params)
         loss_agg = None
         grad_agg = None

@@ -160,7 +161,7 @@ class GradientSampling(Reformulation):
 
         # evaluate at x_0
         if include_x0:
-            f_0 = …
+            f_0 = objective.get_loss(backward=backward)
 
             isfinite = math.isfinite(f_0)
             if isfinite:

@@ -168,7 +169,7 @@ class GradientSampling(Reformulation):
                 loss_agg = f_0
 
             if backward:
-                g_0 = …
+                g_0 = objective.get_grads()
                 if isfinite: grad_agg = g_0
 
         # evaluate at x_0 + p for each perturbation
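The hunks above thread an `Objective` through `GradientSampling` in place of the old positional arguments; the loss and gradients are pulled from it with `get_loss` and `get_grads`. A rough sketch of that access pattern, using only the accessors visible in this diff (the helper function itself is hypothetical):

```python
def evaluate_objective(objective, backward=True):
    """Hypothetical helper mirroring GradientSampling.closure above."""
    loss = objective.get_loss(backward=backward)          # evaluates the closure at the current parameters
    grads = objective.get_grads() if backward else None   # per-parameter gradient list
    return loss, grads
```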
torchzero/modules/step_size/adaptive.py

@@ -5,9 +5,9 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, …
+from ...core import Chainable, TensorTransform
 from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
-from ...…
+from ...linalg.linear_operator import ScaledIdentity
 from ..functional import epsilon_step_size
 
 def _acceptable_alpha(alpha, param:torch.Tensor):

@@ -16,7 +16,7 @@ def _acceptable_alpha(alpha, param:torch.Tensor):
         return False
     return True
 
-def _get_H(self: Transform, var):
+def _get_H(self: TensorTransform, var):
     n = sum(p.numel() for p in var.params)
     p = var.params[0]
     alpha = self.global_state.get('alpha', 1)

@@ -25,7 +25,7 @@ def _get_H(self: Transform, var):
     return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)
 
 
-class PolyakStepSize(Transform):
+class PolyakStepSize(TensorTransform):
     """Polyak's subgradient method with known or unknown f*.
 
     Args:

@@ -47,7 +47,7 @@ class PolyakStepSize(Transform):
         super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
 
     @torch.no_grad
-    def …
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         assert grads is not None and loss is not None
         tensors = TensorList(tensors)
         grads = TensorList(grads)

@@ -79,15 +79,15 @@ class PolyakStepSize(Transform):
         self.global_state['alpha'] = alpha
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', 1)
         if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))
 
         torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
         return tensors
 
-    def get_H(self, …
-        return _get_H(self, …
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
 
 def _bb_short(s: TensorList, y: TensorList, sy, eps):

@@ -116,7 +116,7 @@ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
         return None
     return (short * long) ** 0.5
 
-class BarzilaiBorwein(Transform):
+class BarzilaiBorwein(TensorTransform):
     """Barzilai-Borwein step size method.
 
     Args:

@@ -144,7 +144,7 @@ class BarzilaiBorwein(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def …
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 

@@ -175,11 +175,11 @@ class BarzilaiBorwein(Transform):
         prev_p.copy_(params)
         prev_g.copy_(g)
 
-    def get_H(self, …
-        return _get_H(self, …
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):

@@ -189,7 +189,7 @@ class BarzilaiBorwein(Transform):
         return tensors
 
 
-class BBStab(Transform):
+class BBStab(TensorTransform):
     """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).
 
     This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.

@@ -228,7 +228,7 @@ class BBStab(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def …
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 

@@ -287,11 +287,11 @@ class BBStab(Transform):
         prev_p.copy_(params)
         prev_g.copy_(g)
 
-    def get_H(self, …
-        return _get_H(self, …
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):

@@ -301,7 +301,7 @@ class BBStab(Transform):
         return tensors
 
 
-class AdGD(Transform):
+class AdGD(TensorTransform):
     """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
     def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
         defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)

@@ -313,7 +313,7 @@ class AdGD(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def …
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         variant = settings[0]['variant']
         theta_0 = 0 if variant == 1 else 1/3
         theta = self.global_state.get('theta', theta_0)

@@ -371,7 +371,7 @@ class AdGD(Transform):
         prev_g.copy_(g)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):

@@ -383,5 +383,5 @@ class AdGD(Transform):
         torch._foreach_mul_(tensors, alpha)
         return tensors
 
-    def get_H(self, …
-        return _get_H(self, …
+    def get_H(self, objective):
+        return _get_H(self, objective)
torchzero/modules/step_size/lr.py

@@ -2,7 +2,7 @@
 import torch
 import random
 
-from ...core import …
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, generic_ne, unpack_dicts
 
 def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):

@@ -12,24 +12,24 @@ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
         return tensors * lr
     return tensors
 
-class LR(Transform):
+class LR(TensorTransform):
     """Learning rate. Adding this module also adds support for LR schedulers."""
     def __init__(self, lr: float):
         defaults=dict(lr=lr)
-        super().__init__(defaults…
+        super().__init__(defaults)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
 
-class StepSize(Transform):
+class StepSize(TensorTransform):
     """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
     def __init__(self, step_size: float, key = 'step_size'):
         defaults={"key": key, key: step_size}
-        super().__init__(defaults…
+        super().__init__(defaults)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
 
 

@@ -38,8 +38,8 @@ def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberLi
     if step > steps: return end_lr
     return start_lr + (end_lr - start_lr) * (step / steps)
 
-class Warmup(Transform):
-    """Learning rate warmup, linearly increases learning rate multiplier from …
+class Warmup(TensorTransform):
+    """Learning rate warmup, linearly increases learning rate multiplier from ``start_lr`` to ``end_lr`` over ``steps`` steps.
 
     Args:
         steps (int, optional): number of steps to perform warmup for. Defaults to 100.

@@ -64,7 +64,7 @@ class Warmup(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
         num_steps = settings[0]['steps']
         step = self.global_state.get('step', 0)

@@ -77,7 +77,7 @@ class Warmup(Transform):
         self.global_state['step'] = step + 1
         return tensors
 
-class WarmupNormClip(Transform):
+class WarmupNormClip(TensorTransform):
     """Warmup via clipping of the update norm.
 
     Args:

@@ -102,7 +102,7 @@ class WarmupNormClip(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
         num_steps = settings[0]['steps']
         step = self.global_state.get('step', 0)

@@ -118,8 +118,8 @@ class WarmupNormClip(Transform):
         return tensors
 
 
-class RandomStepSize(Transform):
-    """Uses random global or layer-wise step size from …
+class RandomStepSize(TensorTransform):
+    """Uses random global or layer-wise step size from ``low`` to ``high``.
 
     Args:
         low (float, optional): minimum learning rate. Defaults to 0.

@@ -133,7 +133,7 @@ class RandomStepSize(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         s = settings[0]
         parameterwise = s['parameterwise']
 
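Every step-size module in the two files above now subclasses `TensorTransform` and overrides `multi_tensor_update` / `multi_tensor_apply` instead of the old `Transform` hooks. A minimal sketch of a custom transform in the same shape as `LR`; only the base class and the method signature come from the diff, the class itself is illustrative:

```python
import torch
from torchzero.core import TensorTransform

class NegateUpdate(TensorTransform):
    """Illustrative transform: flips the sign of the update in place."""
    def __init__(self):
        defaults = dict()
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_neg_(tensors)
        return tensors
```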
torchzero/modules/termination/termination.py

@@ -1,11 +1,11 @@
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import cast
+from typing import cast, final
 
 import torch
 
-from ...core import Module, …
+from ...core import Module, Objective
 from ...utils import Metrics, TensorList, safe_dict_update_, tofloat
 
 

@@ -16,14 +16,15 @@ class TerminationCriteriaBase(Module):
         super().__init__(defaults)
 
     @abstractmethod
-    def termination_criteria(self, …
+    def termination_criteria(self, objective: Objective) -> bool:
         ...
 
-…
+    @final
+    def should_terminate(self, objective: Objective) -> bool:
         n_bad = self.global_state.get('_n_bad', 0)
         n = self.defaults['_n']
 
-        if self.termination_criteria(…
+        if self.termination_criteria(objective):
             n_bad += 1
             if n_bad >= n:
                 self.global_state['_n_bad'] = 0

@@ -36,12 +37,12 @@ class TerminationCriteriaBase(Module):
         return False
 
 
-    def update(self, …
-…
-        if …
+    def update(self, objective):
+        objective.should_terminate = self.should_terminate(objective)
+        if objective.should_terminate: self.global_state['_n_bad'] = 0
 
-    def apply(self, …
-        return …
+    def apply(self, objective):
+        return objective
 
 
 class TerminateAfterNSteps(TerminationCriteriaBase):

@@ -49,7 +50,7 @@ class TerminateAfterNSteps(TerminationCriteriaBase):
         defaults = dict(steps=steps)
         super().__init__(defaults)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 

@@ -61,16 +62,17 @@ class TerminateAfterNEvaluations(TerminationCriteriaBase):
         defaults = dict(maxevals=maxevals)
         super().__init__(defaults)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective):
         maxevals = self.defaults['maxevals']
-…
+        assert objective.modular is not None
+        return objective.modular.num_evaluations >= maxevals
 
 class TerminateAfterNSeconds(TerminationCriteriaBase):
     def __init__(self, seconds:float, sec_fn = time.time):
         defaults = dict(seconds=seconds, sec_fn=sec_fn)
         super().__init__(defaults)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective):
         max_seconds = self.defaults['seconds']
         sec_fn = self.defaults['sec_fn']
 

@@ -88,10 +90,10 @@ class TerminateByGradientNorm(TerminationCriteriaBase):
         defaults = dict(tol=tol, ord=ord)
         super().__init__(defaults, n=n)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective):
         tol = self.defaults['tol']
         ord = self.defaults['ord']
-        return TensorList(…
+        return TensorList(objective.get_grads()).global_metric(ord) <= tol
 
 
 class TerminateByUpdateNorm(TerminationCriteriaBase):

@@ -100,20 +102,20 @@ class TerminateByUpdateNorm(TerminationCriteriaBase):
         defaults = dict(tol=tol, ord=ord)
         super().__init__(defaults, n=n)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
         tol = self.defaults['tol']
         ord = self.defaults['ord']
 
-        p_prev = self.get_state(…
+        p_prev = self.get_state(objective.params, 'p_prev', cls=TensorList)
         if step == 0:
-            p_prev.copy_(…
+            p_prev.copy_(objective.params)
             return False
 
-        should_terminate = (p_prev - …
-        p_prev.copy_(…
+        should_terminate = (p_prev - objective.params).global_metric(ord) <= tol
+        p_prev.copy_(objective.params)
         return should_terminate
 
 

@@ -122,10 +124,10 @@ class TerminateOnNoImprovement(TerminationCriteriaBase):
         defaults = dict(tol=tol)
         super().__init__(defaults, n=n)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective):
         tol = self.defaults['tol']
 
-        f = tofloat(…
+        f = tofloat(objective.get_loss(False))
         if 'f_min' not in self.global_state:
             self.global_state['f_min'] = f
             return False

@@ -141,9 +143,9 @@ class TerminateOnLossReached(TerminationCriteriaBase):
         defaults = dict(value=value)
         super().__init__(defaults)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective):
         value = self.defaults['value']
-        return …
+        return objective.get_loss(False) <= value
 
 class TerminateAny(TerminationCriteriaBase):
     def __init__(self, *criteria: TerminationCriteriaBase):

@@ -151,9 +153,9 @@ class TerminateAny(TerminationCriteriaBase):
 
         self.set_children_sequence(criteria)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective: Objective) -> bool:
         for c in self.get_children_sequence():
-            if cast(TerminationCriteriaBase, c).termination_criteria(…
+            if cast(TerminationCriteriaBase, c).termination_criteria(objective): return True
 
         return False
 

@@ -163,9 +165,9 @@ class TerminateAll(TerminationCriteriaBase):
 
         self.set_children_sequence(criteria)
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective: Objective) -> bool:
         for c in self.get_children_sequence():
-            if not cast(TerminationCriteriaBase, c).termination_criteria(…
+            if not cast(TerminationCriteriaBase, c).termination_criteria(objective): return False
 
         return True
 

@@ -173,7 +175,7 @@ class TerminateNever(TerminationCriteriaBase):
     def __init__(self):
         super().__init__()
 
-    def termination_criteria(self, …
+    def termination_criteria(self, objective): return False
 
 def make_termination_criteria(
     ftol: float | None = None,
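Custom criteria now implement `termination_criteria(self, objective)`; the `@final` `should_terminate` wrapper above adds the n-consecutive-hits logic, and `update` writes the result to `objective.should_terminate`. A hedged sketch modelled on `TerminateOnLossReached` (the class is illustrative, only the base class and hook signature come from the diff):

```python
import math
from torchzero.modules.termination import TerminationCriteriaBase

class TerminateOnNonFiniteLoss(TerminationCriteriaBase):
    """Illustrative criterion: stop once the loss becomes non-finite."""
    def __init__(self):
        super().__init__()

    def termination_criteria(self, objective):
        return not math.isfinite(float(objective.get_loss(False)))
```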
torchzero/modules/trust_region/cubic_regularization.py

@@ -5,7 +5,7 @@ import torch
 
 from ...core import Chainable, Module
 from ...utils import TensorList, vec_to_tensors
-from ...…
+from ...linalg.linear_operator import LinearOperator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 

@@ -58,7 +58,7 @@ def ls_cubic_solver(f, g:torch.Tensor, H:LinearOperator, M: float, loss_at_param
     for _ in range(it_max):
         r_try = (r_min + r_max) / 2
         lam = r_try * M
-        s_lam = H.…
+        s_lam = H.solve_plus_diag(g, lam).neg()
         # s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
         solver_it += 1
         crit = conv_criterion(s_lam, r_try)
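The new `solve_plus_diag(g, lam)` call replaces the explicit dense solve that is still visible in the comment above: it solves `(B + lam*I) x = g` through the `LinearOperator` interface. A small dense-tensor sketch of that equivalence (plain tensors here, not the package's operator class):

```python
import torch

B = torch.tensor([[2.0, 0.0], [0.0, 4.0]])   # stand-in for the Hessian operator H
g = torch.tensor([1.0, 1.0])
lam = 0.5

# what H.solve_plus_diag(g, lam).neg() computes, written as a dense solve
s_lam = -torch.linalg.solve(B + lam * torch.eye(2), g)
```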
torchzero/modules/trust_region/levenberg_marquardt.py

@@ -2,7 +2,7 @@
 import torch
 
 from ...core import Chainable, Module
-from ...…
+from ...linalg import linear_operator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 

@@ -32,38 +32,31 @@ class LevenbergMarquardt(TrustRegionBase):
         max_attempts (max_attempts, optional):
             maximum number of trust region size size reductions per step. A zero update vector is returned when
             this limit is exceeded. Defaults to 10.
+        adaptive (bool, optional):
+            if True, trust radius is multiplied by square root of gradient norm.
         fallback (bool, optional):
             if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
             be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
         inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
 
-    Examples:
-        Gauss-Newton with Levenberg-Marquardt trust-region
+    ### Examples:
 
-…
+    Gauss-Newton with Levenberg-Marquardt trust-region
 
-…
-…
-…
-…
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
+    )
+    ```
 
-…
-…
-…
-…
-…
-…
-…
-            )
-…
-        First order trust region (hessian is assumed to be identity)
-…
-        .. code-block:: python
-…
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LevenbergMarquardt(tz.m.Identity()),
-            )
+    LM-SR1
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+    )
+    ```
 
     """
     def __init__(

@@ -78,11 +71,12 @@ class LevenbergMarquardt(TrustRegionBase):
         max_attempts: int = 10,
         radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
         y: float = 0,
+        adaptive: bool = False,
         fallback: bool = False,
         update_freq: int = 1,
         inner: Chainable | None = None,
     ):
-        defaults = dict(y=y, fallback=fallback)
+        defaults = dict(y=y, fallback=fallback, adaptive=adaptive)
         super().__init__(
             defaults=defaults,
             hess_module=hess_module,

@@ -103,6 +97,7 @@ class LevenbergMarquardt(TrustRegionBase):
 
     def trust_solve(self, f, g, H, radius, params, closure, settings):
         y = settings['y']
+        adaptive = settings["adaptive"]
 
         if isinstance(H, linear_operator.DenseInverse):
             if settings['fallback']:

@@ -117,12 +112,14 @@ class LevenbergMarquardt(TrustRegionBase):
             )
 
         reg = 1/radius
+        if adaptive: reg = reg * torch.linalg.vector_norm(g).sqrt()
+
         if y == 0:
-            return H.…
+            return H.solve_plus_diag(g, reg) # pyright:ignore[reportAttributeAccessIssue]
 
         diag = H.diagonal()
         diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
         if y != 1: diag = (diag*y) + (1-y)
-        return H.…
+        return H.solve_plus_diag(g, diag*reg)
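With `adaptive=True`, the regularizer `reg = 1/radius` is additionally scaled by the square root of the gradient norm before the `solve_plus_diag` call. A usage sketch combining the docstring example above with the new keyword (the model is a placeholder):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)  # placeholder model

# Gauss-Newton wrapped in a Levenberg-Marquardt trust region,
# using the `adaptive` flag added in this release.
opt = tz.Modular(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.GaussNewton(), adaptive=True),
)
```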