torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/utility.py
CHANGED
@@ -2,78 +2,78 @@ from collections import deque

import torch

-from ...core import Module,
+from ...core import Module, Transform
from ...utils.tensorlist import Distributions, TensorList
-from ...
+from ...linalg.linear_operator import ScaledIdentity

class Clone(Module):
    """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def
-
-        return
+    def apply(self, objective):
+        objective.updates = [u.clone() for u in objective.get_updates()]
+        return objective

class Grad(Module):
    """Outputs the gradient"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def
-
-        return
+    def apply(self, objective):
+        objective.updates = [g.clone() for g in objective.get_grads()]
+        return objective

class Params(Module):
    """Outputs parameters"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def
-
-        return
+    def apply(self, objective):
+        objective.updates = [p.clone() for p in objective.params]
+        return objective

class Zeros(Module):
    """Outputs zeros"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def
-
-        return
+    def apply(self, objective):
+        objective.updates = [torch.zeros_like(p) for p in objective.params]
+        return objective

class Ones(Module):
    """Outputs ones"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
-    def
-
-        return
+    def apply(self, objective):
+        objective.updates = [torch.ones_like(p) for p in objective.params]
+        return objective

class Fill(Module):
-    """Outputs tensors filled with
+    """Outputs tensors filled with ``value``"""
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    @torch.no_grad
-    def
-
-        return
+    def apply(self, objective):
+        objective.updates = [torch.full_like(p, self.settings[p]['value']) for p in objective.params]
+        return objective

class RandomSample(Module):
-    """Outputs tensors filled with random numbers from distribution depending on value of
+    """Outputs tensors filled with random numbers from distribution depending on value of ``distribution``."""
    def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
        defaults = dict(distribution=distribution, variance=variance)
        super().__init__(defaults)

    @torch.no_grad
-    def
+    def apply(self, objective):
        distribution = self.defaults['distribution']
-        variance = self.get_settings(
-
-        return
+        variance = self.get_settings(objective.params, 'variance')
+        objective.updates = TensorList(objective.params).sample_like(distribution=distribution, variance=variance)
+        return objective

class Randn(Module):
    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
@@ -81,43 +81,44 @@ class Randn(Module):
        super().__init__({})

    @torch.no_grad
-    def
-
-        return
+    def apply(self, objective):
+        objective.updates = [torch.randn_like(p) for p in objective.params]
+        return objective

class Uniform(Module):
-    """Outputs tensors filled with random numbers from uniform distribution between
+    """Outputs tensors filled with random numbers from uniform distribution between ``low`` and ``high``."""
    def __init__(self, low: float, high: float):
        defaults = dict(low=low, high=high)
        super().__init__(defaults)

    @torch.no_grad
-    def
-        low,high = self.get_settings(
-
-        return
+    def apply(self, objective):
+        low,high = self.get_settings(objective.params, 'low','high')
+        objective.updates = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(objective.params, low, high)]
+        return objective

class GradToNone(Module):
-    """Sets
+    """Sets ``grad`` attribute to None on ``objective``."""
    def __init__(self): super().__init__()
-    def
-
-        return
+    def apply(self, objective):
+        objective.grads = None
+        return objective

class UpdateToNone(Module):
-    """Sets
+    """Sets ``update`` attribute to None on ``var``."""
    def __init__(self): super().__init__()
-    def
-
-        return
+    def apply(self, objective):
+        objective.updates = None
+        return objective

class Identity(Module):
    """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
-    def
-    def
-
-
+    def update(self, objective): pass
+    def apply(self, objective): return objective
+    def get_H(self, objective):
+        n = sum(p.numel() for p in objective.params)
+        p = objective.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

Noop = Identity
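
The hunks above illustrate the 0.4.0 module protocol: ``apply`` receives an objective, reads ``objective.params`` or ``objective.get_updates()``, writes ``objective.updates``, and returns the objective (the removed ``torchzero/core/var.py`` suggests 0.3.15 passed a ``var`` object instead). As a rough sketch only, modelled on ``Fill`` and ``Clone`` above, here is what a custom module written against that protocol could look like; ``ScaleBy`` and its ``factor`` setting are hypothetical and not part of the package:

import torch
from torchzero.core import Module  # Module is imported from ...core in the hunks above

class ScaleBy(Module):
    """Hypothetical example: multiplies the current update by a per-parameter ``factor`` setting."""
    def __init__(self, factor: float = 0.5):
        super().__init__(dict(factor=factor))

    @torch.no_grad
    def apply(self, objective):
        # read per-parameter settings the same way Fill reads 'value' above
        objective.updates = [u * self.settings[p]['factor']
                             for p, u in zip(objective.params, objective.get_updates())]
        return objective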

torchzero/modules/projections/projection.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Literal

import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Objective
from ...utils import set_storage_, vec_to_tensors


@@ -80,7 +80,7 @@ class _FakeProjectedClosure:
class ProjectionBase(Module, ABC):
    """
    Base class for projections.
-    This is an abstract class, to use it, subclass it and override
+    This is an abstract class, to use it, subclass it and override ``project`` and ``unproject``.

    Args:
        modules (Chainable): modules that will be applied in the projected domain.
@@ -150,8 +150,8 @@ class ProjectionBase(Module, ABC):
    """

    @torch.no_grad
-    def
-        params =
+    def apply(self, objective: Objective):
+        params = objective.params
        settings = [self.settings[p] for p in params]

        def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
@@ -159,16 +159,16 @@ class ProjectionBase(Module, ABC):
            return list(self.project(
                tensors=tensors,
                params=params,
-                grads=
-                loss=
+                grads=objective.grads,
+                loss=objective.loss,
                states=states,
                settings=settings,
                current=current,
            ))

-
+        projected_obj = objective.clone(clone_updates=False, parent=objective)

-        closure =
+        closure = objective.closure

        # if this is True, update and grad were projected simultaneously under current="grads"
        # so update will have to be unprojected with current="grads"
@@ -179,9 +179,9 @@ class ProjectionBase(Module, ABC):
        # but if it has already been computed, it should be projected
        if self._project_params and closure is not None:

-            if self._project_update and
+            if self._project_update and objective.updates is not None:
                # project update only if it already exists
-
+                projected_obj.updates = _project(objective.updates, current='update')

            else:
                # update will be set to gradients on var.get_grad()
@@ -189,43 +189,43 @@ class ProjectionBase(Module, ABC):
                update_is_grad = True

            # project grad only if it already exists
-            if self._project_grad and
-
+            if self._project_grad and objective.grads is not None:
+                projected_obj.grads = _project(objective.grads, current='grads')

        # otherwise update/grad needs to be calculated and projected here
        else:
            if self._project_update:
-                if
+                if objective.updates is None:
                    # update is None, meaning it will be set to `grad`.
                    # we can project grad and use it for update
-                    grad =
-
-
-                    del
+                    grad = objective.get_grads()
+                    projected_obj.grads = _project(grad, current='grads')
+                    projected_obj.updates = [g.clone() for g in projected_obj.grads]
+                    del objective.updates
                    update_is_grad = True

                else:
                    # update exists so it needs to be projected
-                    update =
-
-                    del update,
+                    update = objective.get_updates()
+                    projected_obj.updates = _project(update, current='update')
+                    del update, objective.updates

-            if self._project_grad and
+            if self._project_grad and projected_obj.grads is None:
                # projected_vars.grad may have been projected simultaneously with update
                # but if that didn't happen, it is projected here
-                grad =
-
+                grad = objective.get_grads()
+                projected_obj.grads = _project(grad, current='grads')


        original_params = None
        if self._project_params:
-            original_params = [p.clone() for p in
-            projected_params = _project(
+            original_params = [p.clone() for p in objective.params]
+            projected_params = _project(objective.params, current='params')

        else:
            # make fake params for correct shapes and state storage
            # they reuse update or grad storage for memory efficiency
-            projected_params =
+            projected_params = projected_obj.updates if projected_obj.updates is not None else projected_obj.grads
            assert projected_params is not None

        if self._projected_params is None:
@@ -245,8 +245,8 @@ class ProjectionBase(Module, ABC):
            return list(self.unproject(
                projected_tensors=projected_tensors,
                params=params,
-                grads=
-                loss=
+                grads=objective.grads,
+                loss=objective.loss,
                states=states,
                settings=settings,
                current=current,
@@ -254,19 +254,19 @@ class ProjectionBase(Module, ABC):

        # project closure
        if self._project_params:
-
+            projected_obj.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
                                                             params=params, projected_params=projected_params)

        elif closure is not None:
-
+            projected_obj.closure = _FakeProjectedClosure(closure, project_fn=_project,
                                                           params=params, fake_params=projected_params)

        else:
-
+            projected_obj.closure = None

        # ----------------------------------- step ----------------------------------- #
-
-
+        projected_obj.params = projected_params
+        projected_obj = self.children['modules'].apply(projected_obj)

        # empty fake params storage
        # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
@@ -275,24 +275,24 @@ class ProjectionBase(Module, ABC):
            set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))

        # --------------------------------- unproject -------------------------------- #
-
-
-
-
+        unprojected_obj = projected_obj.clone(clone_updates=False)
+        unprojected_obj.closure = objective.closure
+        unprojected_obj.params = objective.params
+        unprojected_obj.grads = objective.grads # this may also be set by projected_var since it has var as parent

        if self._project_update:
-            assert
-
-            del
+            assert projected_obj.updates is not None
+            unprojected_obj.updates = _unproject(projected_obj.updates, current='grads' if update_is_grad else 'update')
+            del projected_obj.updates

-        del
+        del projected_obj

        # original params are stored if params are projected
        if original_params is not None:
-            for p, o in zip(
+            for p, o in zip(unprojected_obj.params, original_params):
                p.set_(o) # pyright: ignore[reportArgumentType]

-        return
+        return unprojected_obj


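
The ``ProjectionBase`` docstring above says to subclass it and override ``project`` and ``unproject``, and the ``_project``/``_unproject`` helpers call those hooks with keyword arguments (``tensors`` or ``projected_tensors``, ``params``, ``grads``, ``loss``, ``states``, ``settings``, ``current``). A hypothetical sketch against those signatures follows; ``FlattenProjection`` is illustrative only, the import path is taken from the file list above, and the ``__init__`` signature (which takes at least ``modules``) is not shown in these hunks, so the sketch only overrides the two hooks:

import torch
from torchzero.modules.projections.projection import ProjectionBase  # path as listed in this diff

class FlattenProjection(ProjectionBase):
    """Hypothetical projection: inner modules see one flat 1-D tensor instead of per-parameter tensors."""
    def project(self, tensors, params, grads, loss, states, settings, current):
        # concatenate whatever is being projected ('params', 'grads' or 'update') into one vector
        return [torch.cat([t.reshape(-1) for t in tensors])]

    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        # split the flat vector back into the original parameter shapes
        vec = projected_tensors[0]
        out, offset = [], 0
        for p in params:
            out.append(vec[offset:offset + p.numel()].view_as(p))
            offset += p.numel()
        return out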

torchzero/modules/quasi_newton/damping.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Literal, Protocol, overload
import torch

from ...utils import TensorList
-from ...
+from ...linalg.linear_operator import DenseInverse, LinearOperator
from ..functional import safe_clip



torchzero/modules/quasi_newton/lbfgs.py
CHANGED
@@ -4,9 +4,9 @@ from typing import overload

import torch

-from ...core import Chainable,
+from ...core import Chainable, TensorTransform
from ...utils import TensorList, as_tensorlist, unpack_states
-from ...
+from ...linalg.linear_operator import LinearOperator
from ..functional import initial_step_size
from .damping import DampingStrategyType, apply_damping

@@ -154,7 +154,7 @@ class LBFGSLinearOperator(LinearOperator):
        return (n, n)


-class LBFGS(
+class LBFGS(TensorTransform):
    """Limited-memory BFGS algorithm. A line search or trust region is recommended.

    Args:
@@ -226,7 +226,7 @@ class LBFGS(Transform):
            sy_tol=sy_tol,
            damping = damping,
        )
-        super().__init__(defaults,
+        super().__init__(defaults, inner=inner, update_freq=update_freq)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)
@@ -249,7 +249,7 @@ class LBFGS(Transform):
        self.global_state.pop('step', None)

    @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        p = as_tensorlist(params)
        g = as_tensorlist(tensors)
        step = self.global_state.get('step', 0)
@@ -311,14 +311,14 @@ class LBFGS(Transform):
            y_history.append(y)
            sy_history.append(sy)

-    def get_H(self,
+    def get_H(self, objective=...):
        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
        sy_history = self.global_state['sy_history']
        return LBFGSLinearOperator(s_history, y_history, sy_history)

    @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        scale_first = self.defaults['scale_first']

        tensors = as_tensorlist(tensors)
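
``LBFGS`` above (and ``LSR1`` below) now subclasses ``TensorTransform`` and splits its work into ``multi_tensor_update``, which records curvature history, and ``multi_tensor_apply``, which produces the new direction, with ``get_H`` exposing the approximation as a linear operator. A bare-bones sketch of that hook structure, under the assumptions that ``super().__init__()`` may be called without arguments and that ``multi_tensor_apply`` returns the transformed tensors; ``EMATransform`` is hypothetical and not a torchzero module:

import torch
from torchzero.core import TensorTransform  # exported by ...core per the import hunks above

class EMATransform(TensorTransform):
    """Hypothetical transform: replaces each tensor with a running exponential moving average."""
    def __init__(self):
        super().__init__()

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        # keep the running average in the per-parameter state dicts, like LBFGS keeps its s/y history
        for t, state in zip(tensors, states):
            if 'ema' not in state: state['ema'] = t.clone()
            else: state['ema'].lerp_(t, 0.1)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # assumed contract: return the transformed tensors
        return [state['ema'].clone() for state in states]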

torchzero/modules/quasi_newton/lsr1.py
CHANGED
@@ -4,9 +4,9 @@ from operator import itemgetter

import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, TensorTransform, Objective, step
from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
-from ...
+from ...linalg.linear_operator import LinearOperator
from ..functional import initial_step_size
from .damping import DampingStrategyType, apply_damping

@@ -76,7 +76,7 @@ class LSR1LinearOperator(LinearOperator):
        return (n, n)


-class LSR1(
+class LSR1(TensorTransform):
    """Limited-memory SR1 algorithm. A line search or trust region is recommended.

    Args:
@@ -146,7 +146,7 @@ class LSR1(Transform):
            gtol_restart=gtol_restart,
            damping = damping,
        )
-        super().__init__(defaults,
+        super().__init__(defaults, inner=inner, update_freq=update_freq)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)
@@ -167,7 +167,7 @@ class LSR1(Transform):
        self.global_state.pop('step', None)

    @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        p = as_tensorlist(params)
        g = as_tensorlist(tensors)
        step = self.global_state.get('step', 0)
@@ -225,13 +225,13 @@ class LSR1(Transform):
            s_history.append(s)
            y_history.append(y)

-    def get_H(self,
+    def get_H(self, objective=...):
        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
        return LSR1LinearOperator(s_history, y_history)

    @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        scale_first = self.defaults['scale_first']

        tensors = as_tensorlist(tensors)

torchzero/modules/quasi_newton/quasi_newton.py
CHANGED
@@ -5,9 +5,9 @@ from typing import Any, Literal

import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, TensorTransform, Transform
from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
-from ...
+from ...linalg import linear_operator
from ..functional import initial_step_size, safe_clip


@@ -17,7 +17,7 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
    elif state[key].shape != value.shape: state[key] = value
    else: state[key].lerp_(value, 1-beta)

-class HessianUpdateStrategy(
+class HessianUpdateStrategy(TensorTransform, ABC):
    """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.

    This is an abstract class, to use it, subclass it and override ``update_H`` and/or ``update_B``,
@@ -157,7 +157,7 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
        else: P *= init_scale

    @torch.no_grad
-    def
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        p = param.view(-1); g = tensor.view(-1)
        inverse = setting['inverse']
        M_key = 'H' if inverse else 'B'
@@ -223,7 +223,7 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
        state['f_prev'] = loss

    @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        step = state['step']

        if setting['scale_first'] and step == 1:
@@ -250,8 +250,8 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
            self.global_state.clear()
            return tensor.mul_(initial_step_size(tensor))

-    def get_H(self,
-        param =
+    def get_H(self, objective):
+        param = objective.params[0]
        state = self.state[param]
        settings = self.settings[param]
        if "B" in state:
@@ -1005,7 +1005,7 @@ def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
    return g - (y * (s.dot(g) / sy))


-class GradientCorrection(
+class GradientCorrection(TensorTransform):
    """
    Estimates gradient at minima along search direction assuming function is quadratic.

@@ -1027,9 +1027,9 @@ class GradientCorrection(Transform):

    """
    def __init__(self):
-        super().__init__(
+        super().__init__()

-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        if 'p_prev' not in states[0]:
            p_prev = unpack_states(states, tensors, 'p_prev', init=params)
            g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)