torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/utility.py
CHANGED
@@ -2,78 +2,78 @@ from collections import deque
 
 import torch
 
-from ...core import Module, …
+from ...core import Module, Transform
 from ...utils.tensorlist import Distributions, TensorList
-from ...…
+from ...linalg.linear_operator import ScaledIdentity
 
 class Clone(Module):
     """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.updates = [u.clone() for u in objective.get_updates()]
+        return objective
 
 class Grad(Module):
     """Outputs the gradient"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.updates = [g.clone() for g in objective.get_grads()]
+        return objective
 
 class Params(Module):
     """Outputs parameters"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.updates = [p.clone() for p in objective.params]
+        return objective
 
 class Zeros(Module):
     """Outputs zeros"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.updates = [torch.zeros_like(p) for p in objective.params]
+        return objective
 
 class Ones(Module):
     """Outputs ones"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.updates = [torch.ones_like(p) for p in objective.params]
+        return objective
 
 class Fill(Module):
-    """Outputs tensors filled with …
+    """Outputs tensors filled with ``value``"""
     def __init__(self, value: float):
         defaults = dict(value=value)
         super().__init__(defaults)
 
     @torch.no_grad
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.updates = [torch.full_like(p, self.settings[p]['value']) for p in objective.params]
+        return objective
 
 class RandomSample(Module):
-    """Outputs tensors filled with random numbers from distribution depending on value of …
+    """Outputs tensors filled with random numbers from distribution depending on value of ``distribution``."""
    def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
         defaults = dict(distribution=distribution, variance=variance)
         super().__init__(defaults)
 
     @torch.no_grad
-    def …
+    def apply(self, objective):
         distribution = self.defaults['distribution']
-        variance = self.get_settings(…
-        …
-        return …
+        variance = self.get_settings(objective.params, 'variance')
+        objective.updates = TensorList(objective.params).sample_like(distribution=distribution, variance=variance)
+        return objective
 
 class Randn(Module):
     """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
@@ -81,43 +81,44 @@ class Randn(Module):
         super().__init__({})
 
     @torch.no_grad
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.updates = [torch.randn_like(p) for p in objective.params]
+        return objective
 
 class Uniform(Module):
-    """Outputs tensors filled with random numbers from uniform distribution between …
+    """Outputs tensors filled with random numbers from uniform distribution between ``low`` and ``high``."""
     def __init__(self, low: float, high: float):
         defaults = dict(low=low, high=high)
         super().__init__(defaults)
 
     @torch.no_grad
-    def …
-        low,high = self.get_settings(…
-        …
-        return …
+    def apply(self, objective):
+        low,high = self.get_settings(objective.params, 'low','high')
+        objective.updates = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(objective.params, low, high)]
+        return objective
 
 class GradToNone(Module):
-    """Sets …
+    """Sets ``grad`` attribute to None on ``objective``."""
     def __init__(self): super().__init__()
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.grads = None
+        return objective
 
 class UpdateToNone(Module):
-    """Sets …
+    """Sets ``update`` attribute to None on ``var``."""
     def __init__(self): super().__init__()
-    def …
-        …
-        return …
+    def apply(self, objective):
+        objective.updates = None
+        return objective
 
 class Identity(Module):
     """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
     def __init__(self, *args, **kwargs): super().__init__()
-    def …
-    def …
-        …
-        …
+    def update(self, objective): pass
+    def apply(self, objective): return objective
+    def get_H(self, objective):
+        n = sum(p.numel() for p in objective.params)
+        p = objective.params[0]
         return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)
 
 Noop = Identity
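The pattern above repeats through the whole file: each module's old method becomes `apply(self, objective)`, which writes its output into `objective.updates` and returns the objective. As a rough illustration of writing a custom module against this 0.4.x API, here is a minimal sketch; the `NegGrad` class is hypothetical and is modeled on the `Grad` and `Zeros` modules in this diff.

```python
import torch
from torchzero.core import Module

class NegGrad(Module):  # hypothetical example, not part of torchzero
    """Outputs the negated gradient, mirroring Grad/Zeros above."""
    def __init__(self):
        super().__init__({})  # no per-parameter settings

    @torch.no_grad
    def apply(self, objective):
        # write the module's output into `objective.updates` and
        # return the objective, as the 0.4.x modules in this diff do
        objective.updates = [-g for g in objective.get_grads()]
        return objective
```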
torchzero/modules/{functional.py → opt_utils.py}
CHANGED

@@ -30,7 +30,7 @@ def debiased_step_size(
     pow: float = 2,
     alpha: float | NumberList = 1,
 ):
-    """returns multiplier to step size"""
+    """returns multiplier to step size, step starts from 1"""
     if isinstance(beta1, NumberList): beta1 = beta1.fill_none(0)
     if isinstance(beta2, NumberList): beta2 = beta2.fill_none(0)
 
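The docstring tweak clarifies that the bias-correction multiplier assumes 1-based step counting, so the correction is well defined from the very first step. For context, a standalone sketch of the usual Adam-style computation this refers to (assumed standard formula; the real torchzero function also accepts `NumberList` betas, as the hunk shows):

```python
def debiased_step_size_sketch(step: int, beta1: float = 0.9, beta2: float = 0.999,
                              pow: float = 2, alpha: float = 1.0) -> float:
    # with 1-based `step`, both corrections are nonzero at step == 1
    bc1 = 1 - beta1 ** step   # first-moment bias correction
    bc2 = 1 - beta2 ** step   # second-moment bias correction
    return alpha * bc2 ** (1 / pow) / bc1
```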
torchzero/modules/projections/projection.py
CHANGED

@@ -8,7 +8,7 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Module, …
+from ...core import Chainable, Module, Objective
 from ...utils import set_storage_, vec_to_tensors
 
 
@@ -80,7 +80,7 @@ class _FakeProjectedClosure:
 class ProjectionBase(Module, ABC):
     """
     Base class for projections.
-    This is an abstract class, to use it, subclass it and override …
+    This is an abstract class, to use it, subclass it and override ``project`` and ``unproject``.
 
     Args:
         modules (Chainable): modules that will be applied in the projected domain.
@@ -149,9 +149,12 @@ class ProjectionBase(Module, ABC):
             Iterable[torch.Tensor]: unprojected tensors of the same shape as params
         """
 
+    def update(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
+    def apply(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
+
     @torch.no_grad
-    def step(self, …
-        params = …
+    def step(self, objective: Objective):
+        params = objective.params
         settings = [self.settings[p] for p in params]
 
         def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
@@ -159,16 +162,16 @@ class ProjectionBase(Module, ABC):
             return list(self.project(
                 tensors=tensors,
                 params=params,
-                grads=…
-                loss=…
+                grads=objective.grads,
+                loss=objective.loss,
                 states=states,
                 settings=settings,
                 current=current,
             ))
 
-        …
+        projected_obj = objective.clone(clone_updates=False, parent=objective)
 
-        closure = …
+        closure = objective.closure
 
         # if this is True, update and grad were projected simultaneously under current="grads"
         # so update will have to be unprojected with current="grads"
@@ -179,9 +182,9 @@ class ProjectionBase(Module, ABC):
         # but if it has already been computed, it should be projected
         if self._project_params and closure is not None:
 
-            if self._project_update and …
+            if self._project_update and objective.updates is not None:
                 # project update only if it already exists
-                …
+                projected_obj.updates = _project(objective.updates, current='update')
 
             else:
                 # update will be set to gradients on var.get_grad()
@@ -189,43 +192,43 @@ class ProjectionBase(Module, ABC):
                 update_is_grad = True
 
             # project grad only if it already exists
-            if self._project_grad and …
-                …
+            if self._project_grad and objective.grads is not None:
+                projected_obj.grads = _project(objective.grads, current='grads')
 
         # otherwise update/grad needs to be calculated and projected here
         else:
             if self._project_update:
-                if …
+                if objective.updates is None:
                     # update is None, meaning it will be set to `grad`.
                     # we can project grad and use it for update
-                    grad = …
-                    …
-                    …
-                    del …
+                    grad = objective.get_grads()
+                    projected_obj.grads = _project(grad, current='grads')
+                    projected_obj.updates = [g.clone() for g in projected_obj.grads]
+                    del objective.updates
                     update_is_grad = True
 
                 else:
                     # update exists so it needs to be projected
-                    update = …
-                    …
-                    del update, …
+                    update = objective.get_updates()
+                    projected_obj.updates = _project(update, current='update')
+                    del update, objective.updates
 
-            if self._project_grad and …
+            if self._project_grad and projected_obj.grads is None:
                 # projected_vars.grad may have been projected simultaneously with update
                 # but if that didn't happen, it is projected here
-                grad = …
-                …
+                grad = objective.get_grads()
+                projected_obj.grads = _project(grad, current='grads')
 
 
         original_params = None
         if self._project_params:
-            original_params = [p.clone() for p in …
-            projected_params = _project(…
+            original_params = [p.clone() for p in objective.params]
+            projected_params = _project(objective.params, current='params')
 
         else:
             # make fake params for correct shapes and state storage
             # they reuse update or grad storage for memory efficiency
-            projected_params = …
+            projected_params = projected_obj.updates if projected_obj.updates is not None else projected_obj.grads
             assert projected_params is not None
 
         if self._projected_params is None:
@@ -245,8 +248,8 @@ class ProjectionBase(Module, ABC):
             return list(self.unproject(
                 projected_tensors=projected_tensors,
                 params=params,
-                grads=…
-                loss=…
+                grads=objective.grads,
+                loss=objective.loss,
                 states=states,
                 settings=settings,
                 current=current,
@@ -254,19 +257,19 @@ class ProjectionBase(Module, ABC):
 
         # project closure
         if self._project_params:
-            …
+            projected_obj.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
                                                             params=params, projected_params=projected_params)
 
         elif closure is not None:
-            …
+            projected_obj.closure = _FakeProjectedClosure(closure, project_fn=_project,
                                                           params=params, fake_params=projected_params)
 
         else:
-            …
+            projected_obj.closure = None
 
         # ----------------------------------- step ----------------------------------- #
-        …
-        …
+        projected_obj.params = projected_params
+        projected_obj = self.children['modules'].step(projected_obj)
 
         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
@@ -275,24 +278,24 @@ class ProjectionBase(Module, ABC):
             set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))
 
         # --------------------------------- unproject -------------------------------- #
-        …
-        …
-        …
-        …
+        unprojected_obj = projected_obj.clone(clone_updates=False)
+        unprojected_obj.closure = objective.closure
+        unprojected_obj.params = objective.params
+        unprojected_obj.grads = objective.grads # this may also be set by projected_var since it has var as parent
 
         if self._project_update:
-            assert …
-            …
-            del …
+            assert projected_obj.updates is not None
+            unprojected_obj.updates = _unproject(projected_obj.updates, current='grads' if update_is_grad else 'update')
+            del projected_obj.updates
 
-        del …
+        del projected_obj
 
         # original params are stored if params are projected
         if original_params is not None:
-            for p, o in zip(…
+            for p, o in zip(unprojected_obj.params, original_params):
                 p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return …
+        return unprojected_obj
 
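Per the updated docstring, a projection is defined by subclassing `ProjectionBase` and overriding ``project`` and ``unproject``. A minimal sketch against the keyword signatures visible in this hunk (`tensors`/`projected_tensors`, `params`, `grads`, `loss`, `states`, `settings`, `current`); the sign-flip mapping is a toy stand-in, and the import path is assumed from the file list above:

```python
from torchzero.modules.projections.projection import ProjectionBase

class SignFlip(ProjectionBase):  # toy example, not part of torchzero
    def project(self, tensors, params, grads, loss, states, settings, current):
        # map tensors into the "projected" domain; a real projection
        # would do something useful (low-rank factors, reshaping, ...)
        return [t.neg() for t in tensors]

    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        # inverse map back to the parameter domain
        return [t.neg() for t in projected_tensors]
```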
torchzero/modules/quasi_newton/damping.py
CHANGED

@@ -4,8 +4,8 @@ from typing import Literal, Protocol, overload
 import torch
 
 from ...utils import TensorList
-from ...…
-from ..…
+from ...linalg.linear_operator import DenseInverse, LinearOperator
+from ..opt_utils import safe_clip
 
 
 class DampingStrategy(Protocol):
torchzero/modules/quasi_newton/lbfgs.py
CHANGED

@@ -4,10 +4,10 @@ from typing import overload
 
 import torch
 
-from ...core import Chainable, …
+from ...core import Chainable, TensorTransform
 from ...utils import TensorList, as_tensorlist, unpack_states
-from ...…
-from ..…
+from ...linalg.linear_operator import LinearOperator
+from ..opt_utils import initial_step_size
 from .damping import DampingStrategyType, apply_damping
 
 
@@ -154,7 +154,7 @@ class LBFGSLinearOperator(LinearOperator):
         return (n, n)
 
 
-class LBFGS(…
+class LBFGS(TensorTransform):
     """Limited-memory BFGS algorithm. A line search or trust region is recommended.
 
     Args:
@@ -188,7 +188,7 @@ class LBFGS(Transform):
 
     L-BFGS with line search
     ```python
-    opt = tz.…
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LBFGS(100),
         tz.m.Backtracking()
@@ -197,7 +197,7 @@ class LBFGS(Transform):
 
     L-BFGS with trust region
     ```python
-    opt = tz.…
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.TrustCG(tz.m.LBFGS())
     )
@@ -226,7 +226,7 @@ class LBFGS(Transform):
             sy_tol=sy_tol,
             damping = damping,
         )
-        super().__init__(defaults, …
+        super().__init__(defaults, inner=inner, update_freq=update_freq)
 
         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
@@ -249,7 +249,7 @@ class LBFGS(Transform):
         self.global_state.pop('step', None)
 
     @torch.no_grad
-    def …
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         p = as_tensorlist(params)
         g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
@@ -311,14 +311,14 @@ class LBFGS(Transform):
             y_history.append(y)
             sy_history.append(sy)
 
-    def get_H(self, …
+    def get_H(self, objective=...):
         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
         sy_history = self.global_state['sy_history']
         return LBFGSLinearOperator(s_history, y_history, sy_history)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         scale_first = self.defaults['scale_first']
 
         tensors = as_tensorlist(tensors)
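The docstring hunks show that 0.4.x optimizers are now assembled with `tz.Optimizer(...)` (the old entry-point name is truncated to `tz.` in this diff, so it is not reproduced here). Spelling out the line-search configuration from the updated docstring as a runnable sketch:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# L-BFGS with a backtracking line search, as in the updated docstring
opt = tz.Optimizer(
    model.parameters(),
    tz.m.LBFGS(100),
    tz.m.Backtracking()
)
```

Since `LBFGS` now subclasses `TensorTransform`, its per-step work is split into `multi_tensor_update` (maintains the `s`/`y` histories) and `multi_tensor_apply` (computes the direction), with `get_H` exposing the history as an `LBFGSLinearOperator` for trust-region modules.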
torchzero/modules/quasi_newton/lsr1.py
CHANGED

@@ -4,10 +4,10 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Module, …
+from ...core import Chainable, Module, TensorTransform, Objective, step
 from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
-from ...…
-from ..…
+from ...linalg.linear_operator import LinearOperator
+from ..opt_utils import initial_step_size
 from .damping import DampingStrategyType, apply_damping
 
 
@@ -76,7 +76,7 @@ class LSR1LinearOperator(LinearOperator):
         return (n, n)
 
 
-class LSR1(…
+class LSR1(TensorTransform):
     """Limited-memory SR1 algorithm. A line search or trust region is recommended.
 
     Args:
@@ -110,7 +110,7 @@ class LSR1(Transform):
 
     L-SR1 with line search
     ```python
-    opt = tz.…
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SR1(),
         tz.m.StrongWolfe(c2=0.1, fallback=True)
@@ -119,7 +119,7 @@ class LSR1(Transform):
 
     L-SR1 with trust region
    ```python
-    opt = tz.…
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.TrustCG(tz.m.LSR1())
     )
@@ -146,7 +146,7 @@ class LSR1(Transform):
             gtol_restart=gtol_restart,
             damping = damping,
         )
-        super().__init__(defaults, …
+        super().__init__(defaults, inner=inner, update_freq=update_freq)
 
         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
@@ -167,7 +167,7 @@ class LSR1(Transform):
         self.global_state.pop('step', None)
 
     @torch.no_grad
-    def …
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         p = as_tensorlist(params)
         g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
@@ -225,13 +225,13 @@ class LSR1(Transform):
             s_history.append(s)
             y_history.append(y)
 
-    def get_H(self, …
+    def get_H(self, objective=...):
         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
         return LSR1LinearOperator(s_history, y_history)
 
     @torch.no_grad
-    def …
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         scale_first = self.defaults['scale_first']
 
         tensors = as_tensorlist(tensors)
|