torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/projections/cast.py

@@ -0,0 +1,51 @@
+import torch
+from .projection import ProjectionBase
+from ...core import Chainable
+
+class To(ProjectionBase):
+    """Cast modules to specified device and dtype"""
+    def __init__(self, modules: Chainable, dtype: torch.dtype | None, device: torch.types.Device | None = None):
+        defaults = dict(dtype=dtype, device=device)
+        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=defaults)
+
+    @torch.no_grad
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        casted = []
+        for tensor, state, setting in zip(tensors, states, settings):
+            state['dtype'] = tensor.dtype
+            state['device'] = tensor.device
+            tensor = tensor.to(dtype=setting['dtype'], device=setting['device'])
+            casted.append(tensor)
+        return casted
+
+    @torch.no_grad
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        uncasted = []
+        for tensor, state in zip(projected_tensors, states):
+            tensor = tensor.to(dtype=state['dtype'], device=state['device'])
+            uncasted.append(tensor)
+        return uncasted
+
+
+class ViewAsReal(ProjectionBase):
+    """View complex tensors as real tensors. Doesn't affect tensors that are already."""
+    def __init__(self, modules: Chainable):
+        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=None)
+
+    @torch.no_grad
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        views = []
+        for tensor, state in zip(tensors, states):
+            is_complex = torch.is_complex(tensor)
+            state['is_complex'] = is_complex
+            if is_complex: tensor = torch.view_as_real(tensor)
+            views.append(tensor)
+        return views
+
+    @torch.no_grad
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        un_views = []
+        for tensor, state in zip(projected_tensors, states):
+            if state['is_complex']: tensor = torch.view_as_complex(tensor)
+            un_views.append(tensor)
+        return un_views
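As a side note on `ViewAsReal`: the complex-to-real round trip it relies on can be verified with plain `torch` calls. The snippet below is illustrative and not part of the diff.

```python
import torch

z = torch.randn(3, dtype=torch.complex64)

# what inner modules see after project(): a real view with a trailing
# dimension of size 2 holding the real and imaginary parts
r = torch.view_as_real(z)
assert r.shape == (3, 2) and r.dtype == torch.float32

# what unproject() restores
assert torch.equal(torch.view_as_complex(r), z)
```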
torchzero/modules/projections/projection.py

@@ -1,29 +1,35 @@
 import math
-
+import warnings
 from abc import ABC, abstractmethod
-from collections
+from collections import ChainMap, defaultdict
+from collections.abc import Iterable, Mapping, Sequence
+from functools import partial
 from typing import Any, Literal
-
+
 import torch

 from ...core import Chainable, Module, Var
-from ...utils import vec_to_tensors
+from ...utils import set_storage_, vec_to_tensors


-def _make_projected_closure(closure,
+def _make_projected_closure(closure, project_fn, unproject_fn,
                             params: list[torch.Tensor], projected_params: list[torch.Tensor]):
-
     def projected_closure(backward=True):
-
+        # unproject projected params
+        unprojected_params = unproject_fn(projected_tensors=projected_params, current='params')

+        # set actual model parameters to suggested parameters
         with torch.no_grad():
             for p, new_p in zip(params, unprojected_params):
                 p.set_(new_p) # pyright: ignore[reportArgumentType]

+        # evaluate closure with suggested parameters
         if backward:
             loss = closure()
             grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
+
+            # project gradients on backward and set to projected parameter .grad attributes
+            projected_grads = project_fn(grads, current='grads')
             for p, g in zip(projected_params, projected_grads):
                 p.grad = g

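For orientation, the `closure` being wrapped here follows the usual PyTorch re-evaluation pattern, extended with the `backward` flag that `projected_closure(backward=True)` above and `_FakeProjectedClosure.__call__` below use. A minimal self-contained example of such a closure (written for illustration; the model and data are made up):

```python
import torch

model = torch.nn.Linear(4, 1)
X, y = torch.randn(16, 4), torch.randn(16, 1)

def closure(backward: bool = True):
    # re-evaluate the loss; only repopulate .grad when a backward pass is requested
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss
```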
@@ -34,27 +40,44 @@ def _make_projected_closure(closure, var: Var, projection: "Projection",

     return projected_closure

-
-
-
-
-
-
-)
-
-
-
-
-
-
-
-self
-
-
-
-
-
-
+class _FakeProjectedClosure:
+    """This is used when project_params is False. Then the closure is meant to only be used to evaluate the initial gradient.
+    It should just evaluate original closure, project the gradients, and set them to fake params.
+
+    I made it into a class so that it can know and raise when it evaluates closure more than once.
+    """
+    __slots__ = ('closure', 'project_fn', 'params', 'fake_params', 'evaluated')
+    def __init__(self, closure, project_fn, params: list[torch.Tensor], fake_params: list[torch.Tensor]):
+        self.closure = closure
+        self.project_fn = project_fn
+        self.params = params
+        self.fake_params = fake_params
+        self.evaluated = False
+
+    def __call__(self, backward: bool = True):
+        if self.evaluated:
+            raise RuntimeError("set project_params to True if projected modules require closure.")
+        self.evaluated = True
+
+        # evaluate closure with suggested parameters
+        if backward:
+
+            loss = self.closure()
+            grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+            # project gradients on backward and set to projected parameter .grad attributes
+            projected_grads = self.project_fn(grads, current='grads')
+            for p, g in zip(self.fake_params, projected_grads):
+                p.grad = g
+
+        else:
+            loss = self.closure(False)
+
+        return loss
+
+
+
+class ProjectionBase(Module, ABC):
     """
     Base class for projections.
     This is an abstract class, to use it, subclass it and override `project` and `unproject`.
@@ -84,52 +107,120 @@ class Projection(Module, ABC):
         self._project_grad = project_grad
         self._projected_params = None

+        self._states: dict[str, list[dict[str, Any]]] = {}
+        """per-parameter states for each projection target"""
+
     @abstractmethod
-    def project(
+    def project(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: list[ChainMap[str, Any]],
+        current: str,
+    ) -> Iterable[torch.Tensor]:
         """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""

     @abstractmethod
-    def unproject(
-
+    def unproject(
+        self,
+        projected_tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: list[ChainMap[str, Any]],
+        current: str,
+    ) -> Iterable[torch.Tensor]:
+        """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`.
+
+        Args:
+            projected_tensors (list[torch.Tensor]): projected tensors to unproject.
+            params (list[torch.Tensor]): original, unprojected parameters.
+            grads (list[torch.Tensor] | None): original, unprojected gradients
+            loss (torch.Tensor | None): loss at initial point.
+            states (list[dict[str, Any]]): list of state dictionaries per each UNPROJECTED tensor.
+            settings (list[ChainMap[str, Any]]): list of setting dictionaries per each UNPROJECTED tensor.
+            current (str): string representing what is being unprojected, e.g. "params", "grads" or "update".
+
+        Returns:
+            Iterable[torch.Tensor]: unprojected tensors of the same shape as params
+        """

     @torch.no_grad
     def step(self, var: Var):
-
+        params = var.params
+        settings = [self.settings[p] for p in params]
+
+        def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
+            states = self._states.setdefault(current, [{} for _ in params])
+            return list(self.project(
+                tensors=tensors,
+                params=params,
+                grads=var.grad,
+                loss=var.loss,
+                states=states,
+                settings=settings,
+                current=current,
+            ))
+
+        projected_var = var.clone(clone_update=False, parent=var)
+
+        closure = var.closure
+
+        # if this is True, update and grad were projected simultaneously under current="grads"
+        # so update will have to be unprojected with current="grads"
         update_is_grad = False

-        # closure
-
-
+        # if closure is provided and project_params=True, make new closure that evaluates projected params
+        # that also means projected modules can evaluate grad/update at will, it shouldn't be computed here
+        # but if it has already been computed, it should be projected
+        if self._project_params and closure is not None:
+
+            if self._project_update and var.update is not None:
+                # project update only if it already exists
+                projected_var.update = _project(var.update, current='update')
+
             else:
+                # update will be set to gradients on var.get_grad()
+                # therefore projection will happen with current="grads"
                 update_is_grad = True
-            if self._project_grad and var.grad is not None: projected_var.grad = list(self.project(var.grad, var=var, current='grads'))

-
+            # project grad only if it already exists
+            if self._project_grad and var.grad is not None:
+                projected_var.grad = _project(var.grad, current='grads')
+
+        # otherwise update/grad needs to be calculated and projected here
         else:
             if self._project_update:
                 if var.update is None:
                     # update is None, meaning it will be set to `grad`.
                     # we can project grad and use it for update
                     grad = var.get_grad()
-                    projected_var.grad =
-
-                else: projected_var.update = projected_var.grad.copy() # don't clone because grad shouldn't be used
+                    projected_var.grad = _project(grad, current='grads')
+                    projected_var.update = [g.clone() for g in projected_var.grad]
                     del var.update
                     update_is_grad = True

                 else:
+                    # update exists so it needs to be projected
                     update = var.get_update()
-                    projected_var.update =
+                    projected_var.update = _project(update, current='update')
                     del update, var.update

             if self._project_grad and projected_var.grad is None:
+                # projected_vars.grad may have been projected simultaneously with update
+                # but if that didn't happen, it is projected here
                 grad = var.get_grad()
-                projected_var.grad =
+                projected_var.grad = _project(grad, current='grads')
+

         original_params = None
         if self._project_params:
             original_params = [p.clone() for p in var.params]
-            projected_params =
+            projected_params = _project(var.params, current='params')

         else:
             # make fake params for correct shapes and state storage
@@ -146,49 +237,57 @@ class Projection(Module, ABC):
             for empty_p, new_p in zip(self._projected_params, projected_params):
                 empty_p.set_(new_p.view_as(new_p).requires_grad_()) # pyright: ignore[reportArgumentType]

+        projected_params = self._projected_params
+        # projected_settings = [self.settings[p] for p in projected_params]
+
+        def _unproject(projected_tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
+            states = self._states.setdefault(current, [{} for _ in params])
+            return list(self.unproject(
+                projected_tensors=projected_tensors,
+                params=params,
+                grads=var.grad,
+                loss=var.loss,
+                states=states,
+                settings=settings,
+                current=current,
+            ))
+
         # project closure
         if self._project_params:
-            closure =
-
-
+            projected_var.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
+                                                            params=params, projected_params=projected_params)
+
+        elif closure is not None:
+            projected_var.closure = _FakeProjectedClosure(closure, project_fn=_project,
+                                                          params=params, fake_params=projected_params)

         else:
             projected_var.closure = None

-        # step
-        projected_var.params =
-        projected_var.get_grad = partial(
-            _projected_get_grad_override,
-            projection=self,
-            unprojected_var=var,
-            self=projected_var,
-        )
+        # ----------------------------------- step ----------------------------------- #
+        projected_var.params = projected_params
         projected_var = self.children['modules'].step(projected_var)

         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
         if not self._project_params:
             for p in self._projected_params:
-                p
+                set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))

-        # unproject
+        # --------------------------------- unproject -------------------------------- #
         unprojected_var = projected_var.clone(clone_update=False)
         unprojected_var.closure = var.closure
         unprojected_var.params = var.params
-        unprojected_var.grad = var.grad
+        unprojected_var.grad = var.grad # this may also be set by projected_var since it has var as parent

         if self._project_update:
             assert projected_var.update is not None
-            unprojected_var.update =
+            unprojected_var.update = _unproject(projected_var.update, current='grads' if update_is_grad else 'update')
             del projected_var.update

-        # unprojecting grad doesn't make sense?
-        # if self._project_grad:
-        #     assert projected_var.grad is not None
-        #     unprojected_var.grad = list(self.unproject(projected_var.grad, var=var))
-
         del projected_var

+        # original params are stored if params are projected
         if original_params is not None:
             for p, o in zip(unprojected_var.params, original_params):
                 p.set_(o) # pyright: ignore[reportArgumentType]
@@ -197,48 +296,43 @@ class Projection(Module, ABC):



-
-
-
-
-
-
-
-
-
-
-        return [torch.cat([u.view(-1) for u in tensors], dim=-1).flip(0)]
-
-    @torch.no_grad
-    def unproject(self, tensors, var, current):
-        return vec_to_tensors(vec=tensors[0].flip(0), reference=var.params)
-
-
-class NoopProjection(Projection):
-    """an example projection which doesn't do anything for testing"""
-
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+# basic examples
+class VectorProjection(ProjectionBase):
+    """projection that concatenates all parameters into a vector"""
+    def __init__(
+        self,
+        modules: Chainable,
+        project_update=True,
+        project_params=True,
+        project_grad=True,
+    ):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors,
-        return tensors
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        return [torch.cat([t.ravel() for t in tensors])]

     @torch.no_grad
-    def unproject(self,
-        return
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        return vec_to_tensors(vec=projected_tensors[0], reference=params)

-class MultipyProjection(Projection):
-    """an example projection which multiplies everything by 2"""

-
+class ScalarProjection(ProjectionBase):
+    """projetion that splits all parameters into individual scalars"""
+    def __init__(
+        self,
+        modules: Chainable,
+        project_update=True,
+        project_params=True,
+        project_grad=True,
+    ):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors,
-        return
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        return [s for t in tensors for s in t.ravel().unbind(0)]

     @torch.no_grad
-    def unproject(self,
-        return torch.
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        return vec_to_tensors(vec=torch.stack(projected_tensors), reference=params)

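Based on the `VectorProjection` / `ScalarProjection` examples above, a custom projection appears to come down to overriding `project` and `unproject` with the new signatures, with the only requirement that `unproject` inverts `project`. Below is a hedged sketch of such a subclass; the import paths and the sign-flip idea are assumptions for illustration, not something defined in this diff.

```python
import torch

# assumed import locations, inferred from the file layout shown in this diff
from torchzero.modules.projections.projection import ProjectionBase
from torchzero.utils import vec_to_tensors


class NegateProjection(ProjectionBase):
    """Illustrative only: lets the wrapped modules operate on a negated flat vector."""
    def __init__(self, modules, project_update=True, project_params=True, project_grad=True):
        super().__init__(modules, project_update=project_update,
                         project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        # flatten all tensors into one vector and flip the sign
        return [-torch.cat([t.ravel() for t in tensors])]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        # undo the sign flip and restore the original per-parameter shapes
        return vec_to_tensors(vec=-projected_tensors[0], reference=params)
```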
torchzero/modules/quasi_newton/__init__.py

@@ -1,28 +1,22 @@
-from .
-
-
-
-
-
-
-    LiuStorey,
-    PolakRibiere,
-    ProjectedGradientMethod,
+from .diagonal_quasi_newton import (
+    DNRTR,
+    DiagonalBFGS,
+    DiagonalQuasiCauchi,
+    DiagonalSR1,
+    DiagonalWeightedQuasiCauchi,
+    NewDQN,
 )
 from .lbfgs import LBFGS
 from .lsr1 import LSR1
-from .olbfgs import OnlineLBFGS
-
-# from .experimental import ModularLBFGS
 from .quasi_newton import (
     BFGS,
     DFP,
+    ICUM,
     PSB,
     SR1,
     SSVM,
     BroydenBad,
     BroydenGood,
-    ColumnUpdatingMethod,
     FletcherVMM,
     GradientCorrection,
     Greenstadt1,
@@ -32,5 +26,6 @@ from .quasi_newton import (
     NewSSM,
     Pearson,
     ProjectedNewtonRaphson,
+    ShorR,
     ThomasOptimalMethod,
 )
torchzero/modules/quasi_newton/damping.py

@@ -0,0 +1,105 @@
+import math
+from typing import Literal, Protocol, overload
+
+import torch
+
+from ...utils import TensorList
+from ...utils.linalg.linear_operator import DenseInverse, LinearOperator
+from ..functional import safe_clip
+
+
+class DampingStrategy(Protocol):
+    def __call__(
+        self,
+        s: torch.Tensor,
+        y: torch.Tensor,
+        g: torch.Tensor,
+        H: LinearOperator,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return s, y
+
+def _sy_Hs_sHs(s:torch.Tensor, y:torch.Tensor, H:LinearOperator):
+    if isinstance(H, DenseInverse):
+        Hs = H.solve(y)
+        sHs = y.dot(Hs)
+    else:
+        Hs = H.matvec(s)
+        sHs = s.dot(Hs)
+
+    return s.dot(y), Hs, sHs
+
+
+
+def powell_damping(s:torch.Tensor, y:torch.Tensor, g:torch.Tensor, H:LinearOperator, u=0.2):
+    # here H is hessian! not the inverse
+
+    sy, Hs, sHs = _sy_Hs_sHs(s, y, H)
+    if sy < u*sHs:
+        phi = ((1-u) * sHs) / safe_clip((sHs - sy))
+        s = phi * s + (1 - phi) * Hs
+
+    return s, y
+
+def double_damping(s:torch.Tensor, y:torch.Tensor, g:torch.Tensor, H:LinearOperator, u1=0.2, u2=1/3):
+    # Goldfarb, Donald, Yi Ren, and Achraf Bahamou. "Practical quasi-newton methods for training deep neural networks." Advances in Neural Information Processing Systems 33 (2020): 2386-2396.
+
+    # Powell’s damping on H
+    sy, Hs, sHs = _sy_Hs_sHs(s, y, H)
+    if sy < u1*sHs:
+        phi = ((1-u1) * sHs) / safe_clip(sHs - sy)
+        s = phi * s + (1 - phi) * Hs
+
+    # Powell’s damping with B = I
+    sy = s.dot(y)
+    ss = s.dot(s)
+
+    if sy < u2*ss:
+        phi = ((1-u2) * ss) / safe_clip(ss - sy)
+        y = phi * y + (1 - phi) * s
+
+    return s, y
+
+
+
+_DAMPING_KEYS = Literal["powell", "double"]
+_DAMPING_STRATEGIES: dict[_DAMPING_KEYS, DampingStrategy] = {
+    "powell": powell_damping,
+    "double": double_damping,
+}
+
+
+DampingStrategyType = _DAMPING_KEYS | DampingStrategy | None
+
+@overload
+def apply_damping(
+    strategy: DampingStrategyType,
+    s: torch.Tensor,
+    y: torch.Tensor,
+    g: torch.Tensor,
+    H: LinearOperator,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
+@overload
+def apply_damping(
+    strategy: DampingStrategyType,
+    s: TensorList,
+    y: TensorList,
+    g: TensorList,
+    H: LinearOperator,
+) -> tuple[TensorList, TensorList]: ...
+def apply_damping(
+    strategy: DampingStrategyType,
+    s,
+    y,
+    g,
+    H: LinearOperator,
+):
+    if strategy is None: return s, y
+    if isinstance(strategy, str): strategy = _DAMPING_STRATEGIES[strategy]
+
+    if isinstance(s, TensorList):
+        assert isinstance(y, TensorList) and isinstance(g, TensorList)
+        s_vec, y_vec = strategy(s.to_vec(), y.to_vec(), g.to_vec(), H)
+        return s.from_vec(s_vec), y.from_vec(y_vec)
+
+    assert isinstance(y, torch.Tensor) and isinstance(g, torch.Tensor)
+    return strategy(s, y, g, H)