torchzero 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +4 -10
- torchzero/core/__init__.py +4 -1
- torchzero/core/chain.py +50 -0
- torchzero/core/functional.py +37 -0
- torchzero/core/modular.py +237 -0
- torchzero/core/module.py +12 -599
- torchzero/core/reformulation.py +3 -1
- torchzero/core/transform.py +7 -5
- torchzero/core/var.py +376 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/conjugate_gradient/cg.py +16 -16
- torchzero/modules/experimental/__init__.py +1 -0
- torchzero/modules/experimental/newtonnewton.py +5 -5
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/functional.py +7 -0
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +2 -5
- torchzero/modules/grad_approximation/rfdm.py +27 -110
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +1 -1
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +11 -20
- torchzero/modules/line_search/scipy.py +15 -3
- torchzero/modules/line_search/strong_wolfe.py +3 -5
- torchzero/modules/misc/misc.py +2 -2
- torchzero/modules/misc/multistep.py +13 -13
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/quasi_newton.py +15 -6
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +5 -4
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +89 -0
- torchzero/modules/second_order/inm.py +105 -0
- torchzero/modules/second_order/newton.py +103 -193
- torchzero/modules/second_order/newton_cg.py +86 -110
- torchzero/modules/second_order/nystrom.py +1 -1
- torchzero/modules/second_order/rsn.py +227 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +6 -4
- torchzero/modules/wrappers/optim_wrapper.py +49 -42
- torchzero/modules/zeroth_order/__init__.py +1 -1
- torchzero/modules/zeroth_order/cd.py +1 -238
- torchzero/utils/derivatives.py +19 -19
- torchzero/utils/linalg/linear_operator.py +50 -2
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +1 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/RECORD +57 -48
- torchzero/modules/higher_order/__init__.py +0 -1
- /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
torchzero/modules/wrappers/optim_wrapper.py
CHANGED

@@ -10,34 +10,48 @@ class Wrap(Module):
     """
     Wraps a pytorch optimizer to use it as a module.
 
-
-    Custom param groups are supported only by
+    Note:
+        Custom param groups are supported only by ``set_param_groups``, settings passed to Modular will be applied to all parameters.
 
     Args:
         opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
-            function that takes in parameters and returns the optimizer, for example
-            or
+            function that takes in parameters and returns the optimizer, for example ``torch.optim.Adam``
+            or ``lambda parameters: torch.optim.Adam(parameters, lr=1e-3)``
         *args:
         **kwargs:
-            Extra args to be passed to opt_fn. The function is called as
+            Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
+        use_param_groups:
+            Whether to pass settings passed to Modular to the wrapped optimizer.
 
-
-
+            Note that settings to the first parameter are used for all parameters,
+            so if you specified per-parameter settings, they will be ignored.
 
-
+    ### Example:
+    wrapping pytorch_optimizer.StableAdamW
 
-
-    opt = tz.Modular(
-        model.parameters(),
-        tz.m.Wrap(StableAdamW, lr=1),
-        tz.m.Cautious(),
-        tz.m.LR(1e-2)
-    )
+    ```python
 
+    from pytorch_optimizer import StableAdamW
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Wrap(StableAdamW, lr=1),
+        tz.m.Cautious(),
+        tz.m.LR(1e-2)
+    )
+    ```
 
     """
-
-
+
+    def __init__(
+        self,
+        opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer,
+        *args,
+        use_param_groups: bool = True,
+        **kwargs,
+    ):
+        defaults = dict(use_param_groups=use_param_groups)
+        super().__init__(defaults=defaults)
+
         self._opt_fn = opt_fn
         self._opt_args = args
         self._opt_kwargs = kwargs

@@ -48,7 +62,7 @@ class Wrap(Module):
             self.optimizer = self._opt_fn
 
     def set_param_groups(self, param_groups):
-        self._custom_param_groups = param_groups
+        self._custom_param_groups = _make_param_groups(param_groups, differentiable=False)
         return super().set_param_groups(param_groups)
 
     @torch.no_grad

@@ -61,37 +75,29 @@ class Wrap(Module):
             param_groups = params if self._custom_param_groups is None else self._custom_param_groups
             self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)
 
+        # set optimizer per-parameter settings
+        if self.defaults["use_param_groups"] and var.modular is not None:
+            for group in self.optimizer.param_groups:
+                first_param = group['params'][0]
+                setting = self.settings[first_param]
+
+                # settings passed in `set_param_groups` are the highest priority
+                # schedulers will override defaults but not settings passed in `set_param_groups`
+                # this is consistent with how Modular does it.
+                if self._custom_param_groups is not None:
+                    setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
+
+                group.update(setting)
+
         # set grad to update
         orig_grad = [p.grad for p in params]
         for p, u in zip(params, var.get_update()):
             p.grad = u
 
-        # if this
-
-        # and if there are multiple different per-parameter lrs (would be annoying to support)
-        if var.is_last and (
-            (var.last_module_lrs is None)
-            or
-            (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
-        ):
-            lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]
-
-            # update optimizer lr with desired lr
-            if lr != 1:
-                self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
-                for g in self.optimizer.param_groups:
-                    g['__original_lr__'] = g['lr']
-                    g['lr'] = g['lr'] * lr
-
-            # step
+        # if this is last module, simply use optimizer to update parameters
+        if var.modular is not None and self is var.modular.modules[-1]:
            self.optimizer.step()
 
-            # restore original lr
-            if lr != 1:
-                self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
-                for g in self.optimizer.param_groups:
-                    g['lr'] = g.pop('__original_lr__')
-
            # restore grad
            for p, g in zip(params, orig_grad):
                p.grad = g

@@ -100,6 +106,7 @@ class Wrap(Module):
            return var
 
        # this is not the last module, meaning update is difference in parameters
+        # and passed to next module
        params_before_step = [p.clone() for p in params]
        self.optimizer.step() # step and update params
        for p, g in zip(params, orig_grad):
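The new `use_param_groups` flag controls whether settings that Modular manages (defaults, scheduler overrides) are copied into the wrapped optimizer's param groups before each step, with `set_param_groups` taking highest priority. A minimal sketch of constructing such a wrapper, assuming the `tz.Modular` / `tz.m.Wrap` API shown in the docstring above (the model and hyperparameters are illustrative only):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# Extra kwargs (lr here) are forwarded as opt_fn(parameters, **kwargs).
# With use_param_groups=True (the new default), settings passed to Modular
# are applied to the wrapped optimizer's param_groups before each step.
opt = tz.Modular(
    model.parameters(),
    tz.m.Wrap(torch.optim.SGD, lr=0.1, use_param_groups=True),
)
```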
torchzero/modules/zeroth_order/__init__.py
CHANGED

@@ -1 +1 @@
-from .cd import CD
+from .cd import CD
torchzero/modules/zeroth_order/cd.py
CHANGED

@@ -9,11 +9,10 @@ import torch
 
 from ...core import Module
 from ...utils import NumberList, TensorList
-from ..line_search.adaptive import adaptive_tracking
 
 class CD(Module):
     """Coordinate descent. Proposes a descent direction along a single coordinate.
-
+    A line search such as ``tz.m.ScipyMinimizeScalar(maxiter=8)`` or a fixed step size can be used after this.
 
     Args:
         h (float, optional): finite difference step size. Defaults to 1e-3.

@@ -121,239 +120,3 @@ class CD(Module):
         var.update = update
         return var
 
-
-def _icd_get_idx(self: Module, params: TensorList):
-    ndim = params.global_numel()
-    igrad = self.get_state(params, "igrad", cls=TensorList)
-
-    # -------------------------- 1st n steps fill igrad -------------------------- #
-    index = self.global_state.get('index', 0)
-    self.global_state['index'] = index + 1
-    if index < ndim:
-        return index, igrad
-
-    # ------------------ select randomly weighted by magnitudes ------------------ #
-    igrad_abs = igrad.abs()
-    gmin = igrad_abs.global_min()
-    gmax = igrad_abs.global_max()
-
-    pmin, pmax, pow = self.get_settings(params, "pmin", "pmax", "pow", cls=NumberList)
-
-    p: TensorList = ((igrad_abs - gmin) / (gmax - gmin)) ** pow # pyright:ignore[reportOperatorIssue]
-    p.mul_(pmax-pmin).add_(pmin)
-
-    if 'np_gen' not in self.global_state:
-        self.global_state['np_gen'] = np.random.default_rng(0)
-    np_gen = self.global_state['np_gen']
-
-    p_vec = p.to_vec()
-    p_sum = p_vec.sum()
-    if p_sum > 1e-12:
-        return np_gen.choice(ndim, p=p_vec.div_(p_sum).numpy(force=True)), igrad
-
-    # --------------------- sum is too small, do cycle again --------------------- #
-    self.global_state.clear()
-    self.clear_state_keys('h_vec', 'igrad', 'alphas')
-
-    if 'generator' not in self.global_state:
-        self.global_state['generator'] = random.Random(0)
-    generator = self.global_state['generator']
-    return generator.randrange(0, p_vec.numel()), igrad
-
-class CCD(Module):
-    """Cumulative coordinate descent. This updates one gradient coordinate at a time and accumulates it
-    to the update direction. The coordinate updated is random weighted by magnitudes of current update direction.
-    As update direction ceases to be a descent direction due to old accumulated coordinates, it is decayed.
-
-    Args:
-        pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
-        pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
-        pow (int, optional): power transform to probabilities. Defaults to 2.
-        decay (float, optional): accumulated gradient decay on failed step. Defaults to 0.5.
-        decay2 (float, optional): decay multiplier decay on failed step. Defaults to 0.25.
-        nplus (float, optional): step size increase on successful steps. Defaults to 1.5.
-        nminus (float, optional): step size increase on unsuccessful steps. Defaults to 0.75.
-    """
-    def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay:float=0.8, decay2:float=0.2, nplus=1.5, nminus=0.75):
-
-        defaults = dict(pmin=pmin, pmax=pmax, pow=pow, decay=decay, decay2=decay2, nplus=nplus, nminus=nminus)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None:
-            raise RuntimeError("CD requires closure")
-
-        params = TensorList(var.params)
-        p_prev = self.get_state(params, "p_prev", init=params, cls=TensorList)
-
-        f_0 = var.get_loss(False)
-        step_size = self.global_state.get('step_size', 1)
-
-        # ------------------------ hard reset on infinite loss ----------------------- #
-        if not math.isfinite(f_0):
-            del self.global_state['f_prev']
-            var.update = params - p_prev
-            self.global_state.clear()
-            self.state.clear()
-            self.global_state["step_size"] = step_size / 10
-            return var
-
-        # ---------------------------- soft reset if stuck --------------------------- #
-        if "igrad" in self.state[params[0]]:
-            n_bad = self.global_state.get('n_bad', 0)
-
-            f_prev = self.global_state.get("f_prev", None)
-            if f_prev is not None:
-
-                decay2 = self.defaults["decay2"]
-                decay = self.global_state.get("decay", self.defaults["decay"])
-
-                if f_0 >= f_prev:
-
-                    igrad = self.get_state(params, "igrad", cls=TensorList)
-                    del self.global_state['f_prev']
-
-                    # undo previous update
-                    var.update = params - p_prev
-
-                    # increment n_bad
-                    self.global_state['n_bad'] = n_bad + 1
-
-                    # decay step size
-                    self.global_state['step_size'] = step_size * self.defaults["nminus"]
-
-                    # soft reset
-                    if n_bad > 0:
-                        igrad *= decay
-                        self.global_state["decay"] = decay*decay2
-                        self.global_state['n_bad'] = 0
-
-                    return var
-
-                else:
-                    # increase step size and reset n_bad
-                    self.global_state['step_size'] = step_size * self.defaults["nplus"]
-                    self.global_state['n_bad'] = 0
-                    self.global_state["decay"] = self.defaults["decay"]
-
-        self.global_state['f_prev'] = float(f_0)
-
-        # ------------------------------ determine index ----------------------------- #
-        idx, igrad = _icd_get_idx(self, params)
-
-        # -------------------------- find descent direction -------------------------- #
-        h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
-        h = float(h_vec.flat_get(idx))
-
-        params.flat_set_lambda_(idx, lambda x: x + h)
-        f_p = closure(False)
-
-        params.flat_set_lambda_(idx, lambda x: x - 2*h)
-        f_n = closure(False)
-        params.flat_set_lambda_(idx, lambda x: x + h)
-
-        # ---------------------------------- adapt h --------------------------------- #
-        if f_0 <= f_p and f_0 <= f_n:
-            h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
-        else:
-            if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
-                h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))
-
-        # ------------------------------- update igrad ------------------------------- #
-        if f_0 < f_p and f_0 < f_n: alpha = 0
-        else: alpha = (f_p - f_n) / (2*h)
-
-        igrad.flat_set_(idx, alpha)
-
-        # ----------------------------- create the update ---------------------------- #
-        var.update = igrad * step_size
-        p_prev.copy_(params)
-        return var
-
-
-class CCDLS(Module):
-    """CCD with line search instead of adaptive step size.
-
-    Args:
-        pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
-        pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
-        pow (int, optional): power transform to probabilities. Defaults to 2.
-        decay (float, optional): accumulated gradient decay on failed step. Defaults to 0.5.
-        decay2 (float, optional): decay multiplier decay on failed step. Defaults to 0.25.
-        maxiter (int, optional): max number of line search iterations.
-    """
-    def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay=0.8, decay2=0.2, maxiter=10, ):
-        defaults = dict(pmin=pmin, pmax=pmax, pow=pow, maxiter=maxiter, decay=decay, decay2=decay2)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None:
-            raise RuntimeError("CD requires closure")
-
-        params = TensorList(var.params)
-        finfo = torch.finfo(params[0].dtype)
-        f_0 = var.get_loss(False)
-
-        # ------------------------------ determine index ----------------------------- #
-        idx, igrad = _icd_get_idx(self, params)
-
-        # -------------------------- find descent direction -------------------------- #
-        h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
-        h = float(h_vec.flat_get(idx))
-
-        params.flat_set_lambda_(idx, lambda x: x + h)
-        f_p = closure(False)
-
-        params.flat_set_lambda_(idx, lambda x: x - 2*h)
-        f_n = closure(False)
-        params.flat_set_lambda_(idx, lambda x: x + h)
-
-        # ---------------------------------- adapt h --------------------------------- #
-        if f_0 <= f_p and f_0 <= f_n:
-            h_vec.flat_set_lambda_(idx, lambda x: max(x/2, finfo.tiny * 2))
-        else:
-            # here eps, not tiny
-            if abs(f_0 - f_n) < finfo.eps or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
-                h_vec.flat_set_lambda_(idx, lambda x: min(x*2, finfo.max / 2))
-
-        # ------------------------------- update igrad ------------------------------- #
-        if f_0 < f_p and f_0 < f_n: alpha = 0
-        else: alpha = (f_p - f_n) / (2*h)
-
-        igrad.flat_set_(idx, alpha)
-
-        # -------------------------------- line search ------------------------------- #
-        x0 = params.clone()
-        def f(a):
-            params.sub_(igrad, alpha=a)
-            loss = closure(False)
-            params.copy_(x0)
-            return loss
-
-        a_prev = self.global_state.get('a_prev', 1)
-        a, f_a, niter = adaptive_tracking(f, a_prev, maxiter=self.defaults['maxiter'], f_0=f_0)
-        if (a is None) or (not math.isfinite(a)) or (not math.isfinite(f_a)):
-            a = 0
-
-        # -------------------------------- set a_prev -------------------------------- #
-        decay2 = self.defaults["decay2"]
-        decay = self.global_state.get("decay", self.defaults["decay"])
-
-        if abs(a) > finfo.tiny * 2:
-            assert f_a < f_0
-            self.global_state['a_prev'] = max(min(a, finfo.max / 2), finfo.tiny * 2)
-            self.global_state["decay"] = self.defaults["decay"]
-
-        # ---------------------------- soft reset on fail ---------------------------- #
-        else:
-            igrad *= decay
-            self.global_state["decay"] = decay*decay2
-            self.global_state['a_prev'] = a_prev / 2
-
-        # -------------------------------- set update -------------------------------- #
-        var.update = igrad * a
-        return var
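With the cumulative variants (`CCD`, `CCDLS`) removed, `CD` only proposes a descent direction along one coordinate; per the updated docstring it is meant to be followed by a line search or a fixed step size. A hedged sketch of that composition (the `tz.m.ScipyMinimizeScalar` module and the `h` argument are taken from the docstring above):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# CD proposes a coordinate direction using finite differences with step h;
# the scalar line search then chooses how far to move along that coordinate.
opt = tz.Modular(
    model.parameters(),
    tz.m.CD(h=1e-3),
    tz.m.ScipyMinimizeScalar(maxiter=8),
)
```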
torchzero/utils/derivatives.py
CHANGED
@@ -5,13 +5,13 @@ import torch.autograd.forward_ad as fwAD
 
 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
 
-def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-
-    grad_ouputs = torch.eye(len(
+def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+    flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
+    grad_ouputs = torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype)
     jac = []
-    for i in range(
+    for i in range(flat_outputs.numel()):
         jac.append(torch.autograd.grad(
-
+            flat_outputs,
             wrt,
             grad_ouputs[i],
             retain_graph=True,

@@ -22,12 +22,12 @@ def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], creat
     return [torch.stack(z) for z in zip(*jac)]
 
 
-def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-
+def _jacobian_batched(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+    flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
     return torch.autograd.grad(
-
+        flat_outputs,
         wrt,
-        torch.eye(len(
+        torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype),
         retain_graph=True,
         create_graph=create_graph,
         allow_unused=True,

@@ -51,13 +51,13 @@ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
     return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
 
 
-def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
+def jacobian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
     """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
     Returns a sequence of tensors with the length as `wrt`.
     Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.
 
     Args:
-
+        outputs (Sequence[torch.Tensor]): input sequence of tensors.
         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
         create_graph (bool, optional):
             pytorch option, if True, graph of the derivative will be constructed,

@@ -68,16 +68,16 @@ def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], cr
     Returns:
         sequence of tensors with the length as `wrt`.
     """
-    if batched: return _jacobian_batched(output, wrt, create_graph)
-    return _jacobian(output, wrt, create_graph)
+    if batched: return _jacobian_batched(outputs, wrt, create_graph)
+    return _jacobian(outputs, wrt, create_graph)
 
-def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+def jacobian_and_hessian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
     Calculating hessian requires calculating the jacobian. So this function is more efficient than
     calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
 
     Args:
-
+        outputs (Sequence[torch.Tensor]): input sequence of tensors.
         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
         create_graph (bool, optional):
             pytorch option, if True, graph of the derivative will be constructed,

@@ -87,7 +87,7 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
     Returns:
         tuple with jacobians sequence and hessians sequence.
     """
-    jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+    jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
     return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
 
 

@@ -96,13 +96,13 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
 # Note - I only tested this for cases where input is a scalar."""
 # return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
 
-def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+def jacobian_and_hessian_mat_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
     Calculating hessian requires calculating the jacobian. So this function is more efficient than
     calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
 
     Args:
-
+        outputs (Sequence[torch.Tensor]): input sequence of tensors.
         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
         create_graph (bool, optional):
             pytorch option, if True, graph of the derivative will be constructed,

@@ -112,7 +112,7 @@ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
     Returns:
         tuple with jacobians sequence and hessians sequence.
     """
-    jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+    jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
     H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
     return flatten_jacobian(jac), flatten_jacobian(H_list)
 
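Besides the `output` → `outputs` rename, the functions now flatten and concatenate the outputs (`flat_outputs`), so a sequence of arbitrarily shaped tensors can be differentiated at once. A small sketch of calling the renamed `jacobian_wrt` (import path assumed from the file location; the shape comment follows the docstring's `(*output.shape, *wrt[i].shape)` claim rather than a verified run):

```python
import torch
from torchzero.utils.derivatives import jacobian_wrt

x = torch.randn(3, requires_grad=True)
y = x ** 3  # elementwise map, so the jacobian should be diagonal

# outputs and wrt are both sequences of tensors; batched=True takes the
# eye-matrix grad_outputs path shown in _jacobian_batched above.
(jac,) = jacobian_wrt([y], [x], batched=True)
print(jac.shape)  # (3, 3) per the docstring's shape description
```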
torchzero/utils/linalg/linear_operator.py
CHANGED

@@ -35,8 +35,8 @@ class LinearOperator(ABC):
         """solve with a norm bound on x"""
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")
 
-    def update(self, *args, **kwargs) -> None:
-        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
+    # def update(self, *args, **kwargs) -> None:
+    #     raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
 
     def add(self, x: torch.Tensor) -> "LinearOperator":
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add")

@@ -298,6 +298,7 @@ class AtA(LinearOperator):
 class AAT(LinearOperator):
     def __init__(self, A: torch.Tensor):
         self.A = A
+        self.device = self.A.device; self.dtype = self.A.dtype
 
     def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
     def rmatvec(self, x): return self.matvec(x)

@@ -327,3 +328,50 @@ class AAT(LinearOperator):
         n = self.A.size(1)
         return (n,n)
 
+
+class Sketched(LinearOperator):
+    """A projected by sketching matrix S, representing the operator S @ A_proj @ S.T.
+
+    Where A is (n, n) and S is (n, sketch_size).
+    """
+    def __init__(self, S: torch.Tensor, A_proj: torch.Tensor):
+        self.S = S
+        self.A_proj = A_proj
+        self.device = self.A_proj.device; self.dtype = self.A_proj.dtype
+
+
+    def matvec(self, x):
+        x_proj = self.S.T @ x
+        Ax_proj = self.A_proj @ x_proj
+        return self.S @ Ax_proj
+
+    def rmatvec(self, x):
+        x_proj = self.S.T @ x
+        ATx_proj = self.A_proj.mH @ x_proj
+        return self.S @ ATx_proj
+
+
+    def matmat(self, x): return Dense(torch.linalg.multi_dot([self.S, self.A_proj, self.S.T, x])) # pylint:disable=not-callable
+    def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.S, self.A_proj.mH, self.S.T, x])) # pylint:disable=not-callable
+
+
+    def is_dense(self): return False
+    def to_tensor(self): return self.S @ self.A_proj @ self.S.T
+    def transpose(self): return Sketched(self.S, self.A_proj.mH)
+
+    def add_diagonal(self, x):
+        """this doesn't correspond to adding diagonal to A, however it still works for LM etc."""
+        if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+        if isinstance(x, (int,float)): x = torch.full((self.A_proj.shape[0],), fill_value=x, device=self.A_proj.device, dtype=self.A_proj.dtype)
+        return Sketched(S=self.S, A_proj=self.A_proj + x.diag_embed())
+
+    def solve(self, b):
+        return self.S @ torch.linalg.lstsq(self.A_proj, self.S.T @ b).solution # pylint:disable=not-callable
+
+    def inv(self):
+        return Sketched(S=self.S, A_proj=torch.linalg.pinv(self.A_proj)) # pylint:disable=not-callable
+
+    def size(self):
+        n = self.S.size(0)
+        return (n,n)
+
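The new `Sketched` operator never materializes the full `n x n` matrix: `matvec` projects the vector into the sketch space with `S.T`, applies the small `A_proj`, and lifts the result back with `S`. A quick sanity-check sketch (import path assumed from the file location; random matrices stand in for a real sketch):

```python
import torch
from torchzero.utils.linalg.linear_operator import Sketched

n, k = 8, 3
S = torch.randn(n, k)        # sketching matrix, (n, sketch_size)
A_proj = torch.randn(k, k)   # A projected into the sketch space
op = Sketched(S, A_proj)

x = torch.randn(n)
# matvec computes S @ (A_proj @ (S.T @ x)), which matches the dense form
assert torch.allclose(op.matvec(x), op.to_tensor() @ x, atol=1e-5)
print(op.size())  # (8, 8): the operator acts on the full n-dimensional space
```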
torchzero/utils/optimizer.py
CHANGED
@@ -110,7 +110,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
     for i, param in enumerate(params):
         s = state[param]
         if key not in s:
-            if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+            if must_exist: raise KeyError(f"Key `{key}` doesn't exist in state with keys {tuple(s.keys())}")
             s[key] = _make_initial_state_value(param, init, i)
         values.append(s[key])
     return values

@@ -125,7 +125,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
         s = state[param]
         for k_i, key in enumerate(keys):
             if key not in s:
-                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                if must_exist: raise KeyError(f"Key `{key}` doesn't exist in state with keys {tuple(s.keys())}")
                 k_init = init[k_i] if isinstance(init, (list,tuple)) else init
                 s[key] = _make_initial_state_value(param, k_init, i)
                 values[k_i].append(s[key])
torchzero/utils/python_tools.py
CHANGED