torchzero 0.3.14__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +4 -3
- torchzero/core/__init__.py +4 -1
- torchzero/core/chain.py +50 -0
- torchzero/core/functional.py +37 -0
- torchzero/core/modular.py +237 -0
- torchzero/core/module.py +8 -599
- torchzero/core/reformulation.py +3 -1
- torchzero/core/transform.py +7 -5
- torchzero/core/var.py +376 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/experimental/__init__.py +1 -0
- torchzero/modules/experimental/newtonnewton.py +5 -5
- torchzero/modules/experimental/spsa1.py +2 -2
- torchzero/modules/functional.py +7 -0
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +1 -1
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +11 -20
- torchzero/modules/line_search/strong_wolfe.py +3 -3
- torchzero/modules/misc/misc.py +2 -2
- torchzero/modules/misc/multistep.py +13 -13
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/quasi_newton.py +15 -6
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +89 -0
- torchzero/modules/second_order/inm.py +105 -0
- torchzero/modules/second_order/newton.py +103 -193
- torchzero/modules/second_order/nystrom.py +1 -1
- torchzero/modules/second_order/rsn.py +227 -0
- torchzero/modules/wrappers/optim_wrapper.py +49 -42
- torchzero/utils/derivatives.py +19 -19
- torchzero/utils/linalg/linear_operator.py +50 -2
- {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
- {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/RECORD +44 -36
- torchzero/modules/higher_order/__init__.py +0 -1
- /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from collections import deque
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from ...core import Chainable, Module, apply_transform
|
|
9
|
+
from ...utils import Distributions, TensorList, vec_to_tensors
|
|
10
|
+
from ...utils.linalg.linear_operator import Sketched
|
|
11
|
+
from .newton import _newton_step
|
|
12
|
+
|
|
13
|
+
def _qr_orthonormalize(A:torch.Tensor):
    """Orthonormalize a 2-D matrix with a reduced QR decomposition.

    For a tall or square ``A`` of shape ``(m, n)`` (``m >= n``) this returns the ``(m, n)``
    ``Q`` factor of ``A``, whose columns are orthonormal.

    For a wide ``A`` (``m < n``) the QR decomposition of ``A.T`` is taken instead and its
    ``Q`` factor is transposed back. The result is still ``(m, n)``, but it is the rows that
    are orthonormal.
    """
    rows, cols = A.shape
    is_wide = rows < cols
    Q = torch.linalg.qr(A.T if is_wide else A).Q # pylint:disable=not-callable
    return Q.T if is_wide else Q
|
|
21
|
+
|
|
22
|
+
def _orthonormal_sketch(m, n, dtype, device, generator):
    """Sample an ``(m, n)`` standard Gaussian matrix and orthonormalize it.

    The Gaussian draw uses ``generator``. Orthonormalization is done by ``_qr_orthonormalize``,
    so for ``m >= n`` the columns of the returned matrix are orthonormal.
    """
    gaussian = torch.randn(m, n, dtype=dtype, device=device, generator=generator)
    return _qr_orthonormalize(gaussian)
|
|
24
|
+
|
|
25
|
+
def _gaussian_sketch(m, n, dtype, device, generator):
    """Sample an ``(m, n)`` Gaussian sketch with i.i.d. entries of variance ``1/m``.

    Dividing by ``sqrt(m)`` gives every column unit squared norm in expectation.
    The columns are not orthonormal.
    """
    sample = torch.randn(m, n, dtype=dtype, device=device, generator=generator)
    return sample.div(math.sqrt(m))
|
|
27
|
+
|
|
28
|
+
class RSN(Module):
    """Randomized Subspace Newton. Performs a Newton step in a random subspace.

    On each hessian update a sketch matrix ``S`` of shape ``(ndim, sketch_size)`` is built,
    where ``ndim`` is the total number of parameter elements. The sketched hessian
    ``S.T @ H @ S`` is formed from hessian-matrix products with ``S``. The Newton system is
    solved in that subspace, and the solution is mapped back to parameter space with ``S``.

    Args:
        sketch_size (int):
            size of the random sketch. This many hessian-vector products will need to be evaluated each step.
            It is clipped to the number of parameters.
        sketch_type (str, optional):
            - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
            - "gaussian" - random gaussian (not orthonormal) basis.
            - "common_directions" - uses historical steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt, and at most ``sketch_size`` most recent directions are kept.
            - "mixed" - random orthonormal basis but with three directions set to gradient, slow EMA and fast EMA (default). If ``sketch_size`` is smaller than 3, only the first ``sketch_size`` of those three directions are used.
        damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
        hvp_method (str, optional):
            How to compute hessian-matrix product:

            - "batched" - uses batched autograd
            - "autograd" - uses unbatched autograd
            - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
            - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.

            Defaults to "batched".
        h (float, optional): finite difference step size. Defaults to 1e-2.
        use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to True.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        H_tfm (Callable | None, optional):
            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.

            must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
            which must be True if transform inverted the hessian and False otherwise.

            Or it returns a single tensor which is used as the update.

            Defaults to None.
        eigval_fn (Callable | None, optional):
            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
            If this is specified, eigendecomposition will be used to invert the hessian.
        seed (int | None, optional): seed for random generator. It is used for every random draw of
            this module, including the random initialization of the "mixed" EMAs and the fallback
            direction of "common_directions". Defaults to None.
        inner (Chainable | None, optional): preconditions output of this module. Defaults to None.

    ### Examples

    RSN with line search
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.RSN(),
        tz.m.Backtracking()
    )
    ```

    RSN with trust region
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.RSN()),
    )
    ```


    References:
        1. [Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).](https://arxiv.org/abs/1905.10874)
        2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
    """

    def __init__(
        self,
        sketch_size: int,
        sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
        damping:float=0,
        hvp_method: Literal["batched", "autograd", "forward", "central"] = "batched",
        h: float = 1e-2,
        use_lstsq: bool = True,
        update_freq: int = 1,
        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
        seed: int | None = None,
        inner: Chainable | None = None,
    ):
        defaults = dict(sketch_size=sketch_size, sketch_type=sketch_type,seed=seed,hvp_method=hvp_method, h=h, damping=damping, use_lstsq=use_lstsq, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
        super().__init__(defaults)

        if inner is not None:
            self.set_child("inner", inner)

    @staticmethod
    def _randn_like(x: torch.Tensor, generator) -> torch.Tensor:
        """Like ``torch.randn_like(x)``, but draws from ``generator`` so that ``seed`` is respected."""
        return torch.randn(x.shape, dtype=x.dtype, device=x.device, generator=generator)

    @staticmethod
    def _flat_grad(var, hvp_method) -> torch.Tensor:
        """Return the gradient from ``var`` concatenated into a single flat vector."""
        # autograd-based hvp methods need the graph of the gradient to be kept
        g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
        return torch.cat([t.ravel() for t in g_list])

    def _common_directions_sketch(self, g: torch.Tensor, sketch_size: int, generator) -> torch.Tensor:
        """Build ``S`` from historical steepest descent directions (common-directions method [2]).

        Directions are kept in ``global_state["directions"]``, a deque with ``maxlen=sketch_size``.
        Each call adds the component of ``g`` orthogonal to the stored directions, if it is non-negligible.
        """
        tiny = torch.finfo(g.dtype).tiny * 2

        if "directions" not in self.global_state:
            g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
            if g_norm < tiny:
                # a zero gradient has no direction, so start from a random one instead
                g = self._randn_like(g, generator)
                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable

            self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
            return self.global_state["directions"][0].unsqueeze(1)

        directions: deque = self.global_state["directions"]
        S = torch.stack(tuple(directions), dim=1)

        # Classical Gram-Schmidt applied twice ("twice is enough"). After a single pass, when g
        # is nearly in span(S), the residual is dominated by rounding error that is not
        # orthogonal to S, and normalizing it would break orthonormality of the basis.
        p = g - S @ (S.T @ g)
        p = p - S @ (S.T @ p)
        p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
        if p_norm > tiny:
            directions.append(p / p_norm)
            # Rebuild from the deque: once it is full, ``maxlen`` drops the oldest direction,
            # so S never has more than ``sketch_size`` columns.
            S = torch.stack(tuple(directions), dim=1)

        return S

    def _mixed_sketch(self, g: torch.Tensor, ndim: int, sketch_size: int, device, dtype, generator) -> torch.Tensor:
        """Build an orthonormal ``S`` from the gradient, two gradient EMAs and random gaussian directions."""
        if "slow_ema" not in self.global_state:
            self.global_state["slow_ema"] = self._randn_like(g, generator) * 1e-2
            self.global_state["fast_ema"] = self._randn_like(g, generator) * 1e-2

        slow_ema = self.global_state["slow_ema"]
        fast_ema = self.global_state["fast_ema"]
        slow_ema.lerp_(g, 0.001)
        fast_ema.lerp_(g, 0.1)

        # honour sketch_size even when it is smaller than the 3 informed directions
        S = torch.stack([g, slow_ema, fast_ema][:sketch_size], dim=1)
        if sketch_size > 3:
            S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
            S = torch.cat([S, S_random], dim=1)

        return _qr_orthonormalize(S)

    @torch.no_grad
    def update(self, var):
        """Sample a sketch ``S`` and store it with the sketched hessian ``S.T @ H @ S``.

        Only runs every ``update_freq`` steps. On other steps the previously stored
        ``global_state["S"]`` and ``global_state["H_sketched"]`` are reused by ``apply``.

        Raises:
            RuntimeError: if ``var`` has no closure.
            ValueError: if ``sketch_type`` is unknown.
        """
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        if step % self.defaults['update_freq'] != 0:
            return

        if var.closure is None:
            raise RuntimeError("RSN requires closure")

        params = var.params
        generator = self.get_generator(params[0].device, self.defaults["seed"])

        ndim = sum(p.numel() for p in params)
        device = params[0].device
        dtype = params[0].dtype

        # sample sketch matrix S: (ndim, sketch_size)
        sketch_size = min(self.defaults["sketch_size"], ndim)
        sketch_type = self.defaults["sketch_type"]
        hvp_method = self.defaults["hvp_method"]

        if sketch_type in ('normal', 'gaussian'):
            S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == 'orthonormal':
            S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == 'common_directions':
            g = self._flat_grad(var, hvp_method)
            S = self._common_directions_sketch(g, sketch_size, generator)

        elif sketch_type == "mixed":
            g = self._flat_grad(var, hvp_method)
            S = self._mixed_sketch(g, ndim, sketch_size, device, dtype, generator)

        else:
            raise ValueError(f'Unknown sketch_type {sketch_type}')

        # form sketched hessian
        HS, _ = var.hessian_matrix_product(S, at_x0=True, rgrad=None, hvp_method=hvp_method, normalize=True, retain_graph=False, h=self.defaults["h"])
        H_sketched = S.T @ HS

        self.global_state["H_sketched"] = H_sketched
        self.global_state["S"] = S

    def apply(self, var):
        """Solve the Newton system in the sketched subspace and map the result back with ``S``.

        The gradient is projected as ``S.T @ g`` before solving. The solution returned by
        ``_newton_step`` is lifted to parameter space as ``S @ d_proj`` and written to ``var.update``.
        """
        S: torch.Tensor = self.global_state["S"]
        d_proj = _newton_step(
            var=var,
            H=self.global_state["H_sketched"],
            damping=self.defaults["damping"],
            inner=self.children.get("inner", None),
            H_tfm=self.defaults["H_tfm"],
            eigval_fn=self.defaults["eigval_fn"],
            use_lstsq=self.defaults["use_lstsq"],
            g_proj = lambda g: S.T @ g
        )
        d = S @ d_proj
        var.update = vec_to_tensors(d, var.params)

        return var

    def get_H(self, var=...):
        """Return the sketched hessian as a ``Sketched`` operator representing ``S @ H_sketched @ S.T``.

        If ``eigval_fn`` is set, it is applied to the eigenvalues of ``H_sketched`` first. If the
        eigendecomposition fails, the untransformed ``H_sketched`` is used.
        """
        eigval_fn = self.defaults["eigval_fn"]
        H_sketched: torch.Tensor = self.global_state["H_sketched"]
        S: torch.Tensor = self.global_state["S"]

        if eigval_fn is not None:
            try:
                L, Q = torch.linalg.eigh(H_sketched) # pylint:disable=not-callable
                L = eigval_fn(L)
                H_sketched = Q @ L.diag_embed() @ Q.mH

            except torch.linalg.LinAlgError:
                # best-effort: fall back to the untransformed sketched hessian
                pass

        return Sketched(S, H_sketched)
|
|
@@ -10,34 +10,48 @@ class Wrap(Module):
|
|
|
10
10
|
"""
|
|
11
11
|
Wraps a pytorch optimizer to use it as a module.
|
|
12
12
|
|
|
13
|
-
|
|
14
|
-
Custom param groups are supported only by
|
|
13
|
+
Note:
|
|
14
|
+
Custom param groups are supported only by ``set_param_groups``, settings passed to Modular will be applied to all parameters.
|
|
15
15
|
|
|
16
16
|
Args:
|
|
17
17
|
opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
|
|
18
|
-
function that takes in parameters and returns the optimizer, for example
|
|
19
|
-
or
|
|
18
|
+
function that takes in parameters and returns the optimizer, for example ``torch.optim.Adam``
|
|
19
|
+
or ``lambda parameters: torch.optim.Adam(parameters, lr=1e-3)``
|
|
20
20
|
*args:
|
|
21
21
|
**kwargs:
|
|
22
|
-
Extra args to be passed to opt_fn. The function is called as
|
|
22
|
+
Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
|
|
23
|
+
use_param_groups:
|
|
24
|
+
Whether to pass settings passed to Modular to the wrapped optimizer.
|
|
23
25
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
+
Note that settings to the first parameter are used for all parameters,
|
|
27
|
+
so if you specified per-parameter settings, they will be ignored.
|
|
26
28
|
|
|
27
|
-
|
|
29
|
+
### Example:
|
|
30
|
+
wrapping pytorch_optimizer.StableAdamW
|
|
28
31
|
|
|
29
|
-
|
|
30
|
-
opt = tz.Modular(
|
|
31
|
-
model.parameters(),
|
|
32
|
-
tz.m.Wrap(StableAdamW, lr=1),
|
|
33
|
-
tz.m.Cautious(),
|
|
34
|
-
tz.m.LR(1e-2)
|
|
35
|
-
)
|
|
32
|
+
```python
|
|
36
33
|
|
|
34
|
+
from pytorch_optimizer import StableAdamW
|
|
35
|
+
opt = tz.Modular(
|
|
36
|
+
model.parameters(),
|
|
37
|
+
tz.m.Wrap(StableAdamW, lr=1),
|
|
38
|
+
tz.m.Cautious(),
|
|
39
|
+
tz.m.LR(1e-2)
|
|
40
|
+
)
|
|
41
|
+
```
|
|
37
42
|
|
|
38
43
|
"""
|
|
39
|
-
|
|
40
|
-
|
|
44
|
+
|
|
45
|
+
def __init__(
|
|
46
|
+
self,
|
|
47
|
+
opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer,
|
|
48
|
+
*args,
|
|
49
|
+
use_param_groups: bool = True,
|
|
50
|
+
**kwargs,
|
|
51
|
+
):
|
|
52
|
+
defaults = dict(use_param_groups=use_param_groups)
|
|
53
|
+
super().__init__(defaults=defaults)
|
|
54
|
+
|
|
41
55
|
self._opt_fn = opt_fn
|
|
42
56
|
self._opt_args = args
|
|
43
57
|
self._opt_kwargs = kwargs
|
|
@@ -48,7 +62,7 @@ class Wrap(Module):
|
|
|
48
62
|
self.optimizer = self._opt_fn
|
|
49
63
|
|
|
50
64
|
def set_param_groups(self, param_groups):
|
|
51
|
-
self._custom_param_groups = param_groups
|
|
65
|
+
self._custom_param_groups = _make_param_groups(param_groups, differentiable=False)
|
|
52
66
|
return super().set_param_groups(param_groups)
|
|
53
67
|
|
|
54
68
|
@torch.no_grad
|
|
@@ -61,37 +75,29 @@ class Wrap(Module):
|
|
|
61
75
|
param_groups = params if self._custom_param_groups is None else self._custom_param_groups
|
|
62
76
|
self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)
|
|
63
77
|
|
|
78
|
+
# set optimizer per-parameter settings
|
|
79
|
+
if self.defaults["use_param_groups"] and var.modular is not None:
|
|
80
|
+
for group in self.optimizer.param_groups:
|
|
81
|
+
first_param = group['params'][0]
|
|
82
|
+
setting = self.settings[first_param]
|
|
83
|
+
|
|
84
|
+
# settings passed in `set_param_groups` are the highest priority
|
|
85
|
+
# schedulers will override defaults but not settings passed in `set_param_groups`
|
|
86
|
+
# this is consistent with how Modular does it.
|
|
87
|
+
if self._custom_param_groups is not None:
|
|
88
|
+
setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
|
|
89
|
+
|
|
90
|
+
group.update(setting)
|
|
91
|
+
|
|
64
92
|
# set grad to update
|
|
65
93
|
orig_grad = [p.grad for p in params]
|
|
66
94
|
for p, u in zip(params, var.get_update()):
|
|
67
95
|
p.grad = u
|
|
68
96
|
|
|
69
|
-
# if this
|
|
70
|
-
|
|
71
|
-
# and if there are multiple different per-parameter lrs (would be annoying to support)
|
|
72
|
-
if var.is_last and (
|
|
73
|
-
(var.last_module_lrs is None)
|
|
74
|
-
or
|
|
75
|
-
(('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
|
|
76
|
-
):
|
|
77
|
-
lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]
|
|
78
|
-
|
|
79
|
-
# update optimizer lr with desired lr
|
|
80
|
-
if lr != 1:
|
|
81
|
-
self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
|
|
82
|
-
for g in self.optimizer.param_groups:
|
|
83
|
-
g['__original_lr__'] = g['lr']
|
|
84
|
-
g['lr'] = g['lr'] * lr
|
|
85
|
-
|
|
86
|
-
# step
|
|
97
|
+
# if this is last module, simply use optimizer to update parameters
|
|
98
|
+
if var.modular is not None and self is var.modular.modules[-1]:
|
|
87
99
|
self.optimizer.step()
|
|
88
100
|
|
|
89
|
-
# restore original lr
|
|
90
|
-
if lr != 1:
|
|
91
|
-
self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
|
|
92
|
-
for g in self.optimizer.param_groups:
|
|
93
|
-
g['lr'] = g.pop('__original_lr__')
|
|
94
|
-
|
|
95
101
|
# restore grad
|
|
96
102
|
for p, g in zip(params, orig_grad):
|
|
97
103
|
p.grad = g
|
|
@@ -100,6 +106,7 @@ class Wrap(Module):
|
|
|
100
106
|
return var
|
|
101
107
|
|
|
102
108
|
# this is not the last module, meaning update is difference in parameters
|
|
109
|
+
# and passed to next module
|
|
103
110
|
params_before_step = [p.clone() for p in params]
|
|
104
111
|
self.optimizer.step() # step and update params
|
|
105
112
|
for p, g in zip(params, orig_grad):
|
torchzero/utils/derivatives.py
CHANGED
|
@@ -5,13 +5,13 @@ import torch.autograd.forward_ad as fwAD
|
|
|
5
5
|
|
|
6
6
|
from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
|
|
7
7
|
|
|
8
|
-
def _jacobian(
|
|
9
|
-
|
|
10
|
-
grad_ouputs = torch.eye(len(
|
|
8
|
+
def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
|
|
9
|
+
flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
|
|
10
|
+
grad_ouputs = torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype)
|
|
11
11
|
jac = []
|
|
12
|
-
for i in range(
|
|
12
|
+
for i in range(flat_outputs.numel()):
|
|
13
13
|
jac.append(torch.autograd.grad(
|
|
14
|
-
|
|
14
|
+
flat_outputs,
|
|
15
15
|
wrt,
|
|
16
16
|
grad_ouputs[i],
|
|
17
17
|
retain_graph=True,
|
|
@@ -22,12 +22,12 @@ def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], creat
|
|
|
22
22
|
return [torch.stack(z) for z in zip(*jac)]
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
def _jacobian_batched(
|
|
26
|
-
|
|
25
|
+
def _jacobian_batched(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
|
|
26
|
+
flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
|
|
27
27
|
return torch.autograd.grad(
|
|
28
|
-
|
|
28
|
+
flat_outputs,
|
|
29
29
|
wrt,
|
|
30
|
-
torch.eye(len(
|
|
30
|
+
torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype),
|
|
31
31
|
retain_graph=True,
|
|
32
32
|
create_graph=create_graph,
|
|
33
33
|
allow_unused=True,
|
|
@@ -51,13 +51,13 @@ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
|
|
|
51
51
|
return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
|
|
52
52
|
|
|
53
53
|
|
|
54
|
-
def jacobian_wrt(
|
|
54
|
+
def jacobian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
|
|
55
55
|
"""Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
|
|
56
56
|
Returns a sequence of tensors with the length as `wrt`.
|
|
57
57
|
Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.
|
|
58
58
|
|
|
59
59
|
Args:
|
|
60
|
-
|
|
60
|
+
outputs (Sequence[torch.Tensor]): input sequence of tensors.
|
|
61
61
|
wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
|
|
62
62
|
create_graph (bool, optional):
|
|
63
63
|
pytorch option, if True, graph of the derivative will be constructed,
|
|
@@ -68,16 +68,16 @@ def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], cr
|
|
|
68
68
|
Returns:
|
|
69
69
|
sequence of tensors with the length as `wrt`.
|
|
70
70
|
"""
|
|
71
|
-
if batched: return _jacobian_batched(
|
|
72
|
-
return _jacobian(
|
|
71
|
+
if batched: return _jacobian_batched(outputs, wrt, create_graph)
|
|
72
|
+
return _jacobian(outputs, wrt, create_graph)
|
|
73
73
|
|
|
74
|
-
def jacobian_and_hessian_wrt(
|
|
74
|
+
def jacobian_and_hessian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
|
|
75
75
|
"""Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
|
|
76
76
|
Calculating hessian requires calculating the jacobian. So this function is more efficient than
|
|
77
77
|
calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
|
|
78
78
|
|
|
79
79
|
Args:
|
|
80
|
-
|
|
80
|
+
outputs (Sequence[torch.Tensor]): input sequence of tensors.
|
|
81
81
|
wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
|
|
82
82
|
create_graph (bool, optional):
|
|
83
83
|
pytorch option, if True, graph of the derivative will be constructed,
|
|
@@ -87,7 +87,7 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
|
|
|
87
87
|
Returns:
|
|
88
88
|
tuple with jacobians sequence and hessians sequence.
|
|
89
89
|
"""
|
|
90
|
-
jac = jacobian_wrt(
|
|
90
|
+
jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
|
|
91
91
|
return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
|
|
92
92
|
|
|
93
93
|
|
|
@@ -96,13 +96,13 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
|
|
|
96
96
|
# Note - I only tested this for cases where input is a scalar."""
|
|
97
97
|
# return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
|
|
98
98
|
|
|
99
|
-
def jacobian_and_hessian_mat_wrt(
|
|
99
|
+
def jacobian_and_hessian_mat_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
|
|
100
100
|
"""Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
|
|
101
101
|
Calculating hessian requires calculating the jacobian. So this function is more efficient than
|
|
102
102
|
calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
|
|
103
103
|
|
|
104
104
|
Args:
|
|
105
|
-
|
|
105
|
+
outputs (Sequence[torch.Tensor]): input sequence of tensors.
|
|
106
106
|
wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
|
|
107
107
|
create_graph (bool, optional):
|
|
108
108
|
pytorch option, if True, graph of the derivative will be constructed,
|
|
@@ -112,7 +112,7 @@ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
|
|
|
112
112
|
Returns:
|
|
113
113
|
tuple with jacobians sequence and hessians sequence.
|
|
114
114
|
"""
|
|
115
|
-
jac = jacobian_wrt(
|
|
115
|
+
jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
|
|
116
116
|
H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
|
|
117
117
|
return flatten_jacobian(jac), flatten_jacobian(H_list)
|
|
118
118
|
|
|
@@ -35,8 +35,8 @@ class LinearOperator(ABC):
|
|
|
35
35
|
"""solve with a norm bound on x"""
|
|
36
36
|
raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")
|
|
37
37
|
|
|
38
|
-
def update(self, *args, **kwargs) -> None:
|
|
39
|
-
|
|
38
|
+
# def update(self, *args, **kwargs) -> None:
|
|
39
|
+
# raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
|
|
40
40
|
|
|
41
41
|
def add(self, x: torch.Tensor) -> "LinearOperator":
|
|
42
42
|
raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add")
|
|
@@ -298,6 +298,7 @@ class AtA(LinearOperator):
|
|
|
298
298
|
class AAT(LinearOperator):
|
|
299
299
|
def __init__(self, A: torch.Tensor):
|
|
300
300
|
self.A = A
|
|
301
|
+
self.device = self.A.device; self.dtype = self.A.dtype
|
|
301
302
|
|
|
302
303
|
def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
|
|
303
304
|
def rmatvec(self, x): return self.matvec(x)
|
|
@@ -327,3 +328,50 @@ class AAT(LinearOperator):
|
|
|
327
328
|
n = self.A.size(1)
|
|
328
329
|
return (n,n)
|
|
329
330
|
|
|
331
|
+
|
|
332
|
+
class Sketched(LinearOperator):
    """Low-rank operator ``S @ A_proj @ S.T`` defined by a sketch.

    ``S`` has shape ``(n, k)`` and maps the ``k``-dimensional sketched subspace into the full
    ``n``-dimensional space. ``A_proj`` is the ``(k, k)`` operator expressed in that subspace.
    The represented operator is ``(n, n)``. ``S @ A_proj @ S.T`` is never formed explicitly
    except in ``to_tensor``.
    """
    def __init__(self, S: torch.Tensor, A_proj: torch.Tensor):
        self.S = S
        self.A_proj = A_proj
        self.device = A_proj.device
        self.dtype = A_proj.dtype

    def matvec(self, x):
        # project to the subspace, apply A_proj, lift back
        return self.S @ (self.A_proj @ (self.S.T @ x))

    def rmatvec(self, x):
        # same as matvec but with the conjugate transpose of A_proj
        return self.S @ (self.A_proj.mH @ (self.S.T @ x))

    def matmat(self, x):
        factors = [self.S, self.A_proj, self.S.T, x]
        return Dense(torch.linalg.multi_dot(factors)) # pylint:disable=not-callable

    def rmatmat(self, x):
        factors = [self.S, self.A_proj.mH, self.S.T, x]
        return Dense(torch.linalg.multi_dot(factors)) # pylint:disable=not-callable

    def is_dense(self):
        return False

    def to_tensor(self):
        return (self.S @ self.A_proj) @ self.S.T

    def transpose(self):
        return Sketched(self.S, self.A_proj.mH)

    def add_diagonal(self, x):
        """Add ``x`` to the diagonal of ``A_proj`` (the ``(k, k)`` subspace operator).

        This does not add a diagonal to the full ``(n, n)`` operator, but it is still sufficient
        for damping schemes such as Levenberg-Marquardt. A scalar (or single-element tensor)
        is broadcast to all ``k`` diagonal entries.
        """
        if isinstance(x, torch.Tensor) and x.numel() <= 1:
            x = x.item()
        if isinstance(x, (int, float)):
            k = self.A_proj.shape[0]
            x = torch.full((k,), fill_value=x, device=self.A_proj.device, dtype=self.A_proj.dtype)
        return Sketched(S=self.S, A_proj=self.A_proj + torch.diag_embed(x))

    def solve(self, b):
        """Least-squares solve in the sketched subspace, lifted back with ``S``."""
        b_proj = self.S.T @ b
        x_proj = torch.linalg.lstsq(self.A_proj, b_proj).solution # pylint:disable=not-callable
        return self.S @ x_proj

    def inv(self):
        """Pseudo-inverse taken in the sketched subspace."""
        A_pinv = torch.linalg.pinv(self.A_proj) # pylint:disable=not-callable
        return Sketched(S=self.S, A_proj=A_pinv)

    def size(self):
        return (self.S.size(0),) * 2
|
|
377
|
+
|