torchzero 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. tests/test_opts.py +4 -10
  2. torchzero/core/__init__.py +4 -1
  3. torchzero/core/chain.py +50 -0
  4. torchzero/core/functional.py +37 -0
  5. torchzero/core/modular.py +237 -0
  6. torchzero/core/module.py +12 -599
  7. torchzero/core/reformulation.py +3 -1
  8. torchzero/core/transform.py +7 -5
  9. torchzero/core/var.py +376 -0
  10. torchzero/modules/__init__.py +0 -1
  11. torchzero/modules/adaptive/adahessian.py +2 -2
  12. torchzero/modules/adaptive/esgd.py +2 -2
  13. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  14. torchzero/modules/adaptive/sophia_h.py +2 -2
  15. torchzero/modules/conjugate_gradient/cg.py +16 -16
  16. torchzero/modules/experimental/__init__.py +1 -0
  17. torchzero/modules/experimental/newtonnewton.py +5 -5
  18. torchzero/modules/experimental/spsa1.py +93 -0
  19. torchzero/modules/functional.py +7 -0
  20. torchzero/modules/grad_approximation/__init__.py +1 -1
  21. torchzero/modules/grad_approximation/forward_gradient.py +2 -5
  22. torchzero/modules/grad_approximation/rfdm.py +27 -110
  23. torchzero/modules/line_search/__init__.py +1 -1
  24. torchzero/modules/line_search/_polyinterp.py +3 -1
  25. torchzero/modules/line_search/adaptive.py +3 -3
  26. torchzero/modules/line_search/backtracking.py +1 -1
  27. torchzero/modules/line_search/interpolation.py +160 -0
  28. torchzero/modules/line_search/line_search.py +11 -20
  29. torchzero/modules/line_search/scipy.py +15 -3
  30. torchzero/modules/line_search/strong_wolfe.py +3 -5
  31. torchzero/modules/misc/misc.py +2 -2
  32. torchzero/modules/misc/multistep.py +13 -13
  33. torchzero/modules/quasi_newton/__init__.py +2 -0
  34. torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  35. torchzero/modules/quasi_newton/sg2.py +292 -0
  36. torchzero/modules/restarts/restars.py +5 -4
  37. torchzero/modules/second_order/__init__.py +6 -3
  38. torchzero/modules/second_order/ifn.py +89 -0
  39. torchzero/modules/second_order/inm.py +105 -0
  40. torchzero/modules/second_order/newton.py +103 -193
  41. torchzero/modules/second_order/newton_cg.py +86 -110
  42. torchzero/modules/second_order/nystrom.py +1 -1
  43. torchzero/modules/second_order/rsn.py +227 -0
  44. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  45. torchzero/modules/trust_region/trust_cg.py +6 -4
  46. torchzero/modules/wrappers/optim_wrapper.py +49 -42
  47. torchzero/modules/zeroth_order/__init__.py +1 -1
  48. torchzero/modules/zeroth_order/cd.py +1 -238
  49. torchzero/utils/derivatives.py +19 -19
  50. torchzero/utils/linalg/linear_operator.py +50 -2
  51. torchzero/utils/optimizer.py +2 -2
  52. torchzero/utils/python_tools.py +1 -0
  53. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
  54. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/RECORD +57 -48
  55. torchzero/modules/higher_order/__init__.py +0 -1
  56. /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
  57. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
  58. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0

torchzero/modules/wrappers/optim_wrapper.py
@@ -10,34 +10,48 @@ class Wrap(Module):
     """
     Wraps a pytorch optimizer to use it as a module.
 
-    .. note::
-        Custom param groups are supported only by `set_param_groups`, settings passed to Modular will be ignored.
+    Note:
+        Custom param groups are supported only by ``set_param_groups``, settings passed to Modular will be applied to all parameters.
 
     Args:
         opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
-            function that takes in parameters and returns the optimizer, for example :code:`torch.optim.Adam`
-            or :code:`lambda parameters: torch.optim.Adam(parameters, lr=1e-3)`
+            function that takes in parameters and returns the optimizer, for example ``torch.optim.Adam``
+            or ``lambda parameters: torch.optim.Adam(parameters, lr=1e-3)``
         *args:
         **kwargs:
-            Extra args to be passed to opt_fn. The function is called as :code:`opt_fn(parameters, *args, **kwargs)`.
+            Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
+        use_param_groups:
+            Whether to pass settings passed to Modular to the wrapped optimizer.
 
-    Example:
-        wrapping pytorch_optimizer.StableAdamW
+            Note that settings to the first parameter are used for all parameters,
+            so if you specified per-parameter settings, they will be ignored.
 
-    .. code-block:: py
+    ### Example:
+    wrapping pytorch_optimizer.StableAdamW
 
-        from pytorch_optimizer import StableAdamW
-        opt = tz.Modular(
-            model.parameters(),
-            tz.m.Wrap(StableAdamW, lr=1),
-            tz.m.Cautious(),
-            tz.m.LR(1e-2)
-        )
+    ```python
 
+    from pytorch_optimizer import StableAdamW
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Wrap(StableAdamW, lr=1),
+        tz.m.Cautious(),
+        tz.m.LR(1e-2)
+    )
+    ```
 
     """
-    def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
-        super().__init__()
+
+    def __init__(
+        self,
+        opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer,
+        *args,
+        use_param_groups: bool = True,
+        **kwargs,
+    ):
+        defaults = dict(use_param_groups=use_param_groups)
+        super().__init__(defaults=defaults)
+
         self._opt_fn = opt_fn
         self._opt_args = args
         self._opt_kwargs = kwargs
@@ -48,7 +62,7 @@ class Wrap(Module):
             self.optimizer = self._opt_fn
 
     def set_param_groups(self, param_groups):
-        self._custom_param_groups = param_groups
+        self._custom_param_groups = _make_param_groups(param_groups, differentiable=False)
         return super().set_param_groups(param_groups)
 
     @torch.no_grad
@@ -61,37 +75,29 @@ class Wrap(Module):
             param_groups = params if self._custom_param_groups is None else self._custom_param_groups
             self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)
 
+        # set optimizer per-parameter settings
+        if self.defaults["use_param_groups"] and var.modular is not None:
+            for group in self.optimizer.param_groups:
+                first_param = group['params'][0]
+                setting = self.settings[first_param]
+
+                # settings passed in `set_param_groups` are the highest priority
+                # schedulers will override defaults but not settings passed in `set_param_groups`
+                # this is consistent with how Modular does it.
+                if self._custom_param_groups is not None:
+                    setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
+
+                group.update(setting)
+
         # set grad to update
         orig_grad = [p.grad for p in params]
         for p, u in zip(params, var.get_update()):
             p.grad = u
 
-        # if this module is last, can step with _opt directly
-        # direct step can't be applied if next module is LR but _opt doesn't support lr,
-        # and if there are multiple different per-parameter lrs (would be annoying to support)
-        if var.is_last and (
-            (var.last_module_lrs is None)
-            or
-            (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
-        ):
-            lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]
-
-            # update optimizer lr with desired lr
-            if lr != 1:
-                self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
-                for g in self.optimizer.param_groups:
-                    g['__original_lr__'] = g['lr']
-                    g['lr'] = g['lr'] * lr
-
-            # step
+        # if this is last module, simply use optimizer to update parameters
+        if var.modular is not None and self is var.modular.modules[-1]:
             self.optimizer.step()
 
-            # restore original lr
-            if lr != 1:
-                self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
-                for g in self.optimizer.param_groups:
-                    g['lr'] = g.pop('__original_lr__')
-
             # restore grad
             for p, g in zip(params, orig_grad):
                 p.grad = g
@@ -100,6 +106,7 @@ class Wrap(Module):
             return var
 
        # this is not the last module, meaning update is difference in parameters
+        # and passed to next module
        params_before_step = [p.clone() for p in params]
        self.optimizer.step() # step and update params
        for p, g in zip(params, orig_grad):
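
For reference, a short sketch of how the new ``use_param_groups`` flag fits the docstring example above. The model, the data-free setup, and the standard ``torch.optim``-style construction are illustrative assumptions, not part of the diff:

```python
# Hypothetical usage sketch based on the Wrap docstring and __init__ above.
import torch
import torchzero as tz
from pytorch_optimizer import StableAdamW

model = torch.nn.Linear(10, 2)

opt = tz.Modular(
    model.parameters(),
    # use_param_groups=True (the default) forwards settings passed to Modular,
    # e.g. by schedulers, into the wrapped optimizer's param groups before each step.
    tz.m.Wrap(StableAdamW, lr=1, use_param_groups=True),
    tz.m.Cautious(),
    tz.m.LR(1e-2),
)
```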

torchzero/modules/zeroth_order/__init__.py
@@ -1 +1 @@
-from .cd import CD, CCD, CCDLS
+from .cd import CD

torchzero/modules/zeroth_order/cd.py
@@ -9,11 +9,10 @@ import torch
 
 from ...core import Module
 from ...utils import NumberList, TensorList
-from ..line_search.adaptive import adaptive_tracking
 
 class CD(Module):
     """Coordinate descent. Proposes a descent direction along a single coordinate.
-    You can then put a line search such as ``tz.m.ScipyMinimizeScalar``, or a fixed step size.
+    A line search such as ``tz.m.ScipyMinimizeScalar(maxiter=8)`` or a fixed step size can be used after this.
 
     Args:
         h (float, optional): finite difference step size. Defaults to 1e-3.
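
A hypothetical usage sketch built from the CD docstring above: CD proposes a single-coordinate direction and a scalar line search (or a fixed step size) picks the step along it. The closure convention with a ``backward`` flag mirrors the ``closure(False)`` calls visible in this file; the ``step(closure)`` call assumes the usual ``torch.optim`` interface:

```python
# Illustrative sketch only; module names are taken from the CD docstring above.
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(32, 4), torch.randn(32, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.CD(h=1e-3),                      # propose a descent direction along one coordinate
    tz.m.ScipyMinimizeScalar(maxiter=8),  # line search along that direction
)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # CD re-evaluates the closure at shifted coordinates, so a closure is required
```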
@@ -121,239 +120,3 @@ class CD(Module):
         var.update = update
         return var
 
-
-def _icd_get_idx(self: Module, params: TensorList):
-    ndim = params.global_numel()
-    igrad = self.get_state(params, "igrad", cls=TensorList)
-
-    # -------------------------- 1st n steps fill igrad -------------------------- #
-    index = self.global_state.get('index', 0)
-    self.global_state['index'] = index + 1
-    if index < ndim:
-        return index, igrad
-
-    # ------------------ select randomly weighted by magnitudes ------------------ #
-    igrad_abs = igrad.abs()
-    gmin = igrad_abs.global_min()
-    gmax = igrad_abs.global_max()
-
-    pmin, pmax, pow = self.get_settings(params, "pmin", "pmax", "pow", cls=NumberList)
-
-    p: TensorList = ((igrad_abs - gmin) / (gmax - gmin)) ** pow # pyright:ignore[reportOperatorIssue]
-    p.mul_(pmax-pmin).add_(pmin)
-
-    if 'np_gen' not in self.global_state:
-        self.global_state['np_gen'] = np.random.default_rng(0)
-    np_gen = self.global_state['np_gen']
-
-    p_vec = p.to_vec()
-    p_sum = p_vec.sum()
-    if p_sum > 1e-12:
-        return np_gen.choice(ndim, p=p_vec.div_(p_sum).numpy(force=True)), igrad
-
-    # --------------------- sum is too small, do cycle again --------------------- #
-    self.global_state.clear()
-    self.clear_state_keys('h_vec', 'igrad', 'alphas')
-
-    if 'generator' not in self.global_state:
-        self.global_state['generator'] = random.Random(0)
-    generator = self.global_state['generator']
-    return generator.randrange(0, p_vec.numel()), igrad
-
-class CCD(Module):
-    """Cumulative coordinate descent. This updates one gradient coordinate at a time and accumulates it
-    to the update direction. The coordinate updated is random weighted by magnitudes of current update direction.
-    As update direction ceases to be a descent direction due to old accumulated coordinates, it is decayed.
-
-    Args:
-        pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
-        pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
-        pow (int, optional): power transform to probabilities. Defaults to 2.
-        decay (float, optional): accumulated gradient decay on failed step. Defaults to 0.5.
-        decay2 (float, optional): decay multiplier decay on failed step. Defaults to 0.25.
-        nplus (float, optional): step size increase on successful steps. Defaults to 1.5.
-        nminus (float, optional): step size increase on unsuccessful steps. Defaults to 0.75.
-    """
-    def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay:float=0.8, decay2:float=0.2, nplus=1.5, nminus=0.75):
-
-        defaults = dict(pmin=pmin, pmax=pmax, pow=pow, decay=decay, decay2=decay2, nplus=nplus, nminus=nminus)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None:
-            raise RuntimeError("CD requires closure")
-
-        params = TensorList(var.params)
-        p_prev = self.get_state(params, "p_prev", init=params, cls=TensorList)
-
-        f_0 = var.get_loss(False)
-        step_size = self.global_state.get('step_size', 1)
-
-        # ------------------------ hard reset on infinite loss ----------------------- #
-        if not math.isfinite(f_0):
-            del self.global_state['f_prev']
-            var.update = params - p_prev
-            self.global_state.clear()
-            self.state.clear()
-            self.global_state["step_size"] = step_size / 10
-            return var
-
-        # ---------------------------- soft reset if stuck --------------------------- #
-        if "igrad" in self.state[params[0]]:
-            n_bad = self.global_state.get('n_bad', 0)
-
-            f_prev = self.global_state.get("f_prev", None)
-            if f_prev is not None:
-
-                decay2 = self.defaults["decay2"]
-                decay = self.global_state.get("decay", self.defaults["decay"])
-
-                if f_0 >= f_prev:
-
-                    igrad = self.get_state(params, "igrad", cls=TensorList)
-                    del self.global_state['f_prev']
-
-                    # undo previous update
-                    var.update = params - p_prev
-
-                    # increment n_bad
-                    self.global_state['n_bad'] = n_bad + 1
-
-                    # decay step size
-                    self.global_state['step_size'] = step_size * self.defaults["nminus"]
-
-                    # soft reset
-                    if n_bad > 0:
-                        igrad *= decay
-                        self.global_state["decay"] = decay*decay2
-                        self.global_state['n_bad'] = 0
-
-                    return var
-
-                else:
-                    # increase step size and reset n_bad
-                    self.global_state['step_size'] = step_size * self.defaults["nplus"]
-                    self.global_state['n_bad'] = 0
-                    self.global_state["decay"] = self.defaults["decay"]
-
-        self.global_state['f_prev'] = float(f_0)
-
-        # ------------------------------ determine index ----------------------------- #
-        idx, igrad = _icd_get_idx(self, params)
-
-        # -------------------------- find descent direction -------------------------- #
-        h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
-        h = float(h_vec.flat_get(idx))
-
-        params.flat_set_lambda_(idx, lambda x: x + h)
-        f_p = closure(False)
-
-        params.flat_set_lambda_(idx, lambda x: x - 2*h)
-        f_n = closure(False)
-        params.flat_set_lambda_(idx, lambda x: x + h)
-
-        # ---------------------------------- adapt h --------------------------------- #
-        if f_0 <= f_p and f_0 <= f_n:
-            h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
-        else:
-            if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
-                h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))
-
-        # ------------------------------- update igrad ------------------------------- #
-        if f_0 < f_p and f_0 < f_n: alpha = 0
-        else: alpha = (f_p - f_n) / (2*h)
-
-        igrad.flat_set_(idx, alpha)
-
-        # ----------------------------- create the update ---------------------------- #
-        var.update = igrad * step_size
-        p_prev.copy_(params)
-        return var
-
-
-class CCDLS(Module):
-    """CCD with line search instead of adaptive step size.
-
-    Args:
-        pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
-        pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
-        pow (int, optional): power transform to probabilities. Defaults to 2.
-        decay (float, optional): accumulated gradient decay on failed step. Defaults to 0.5.
-        decay2 (float, optional): decay multiplier decay on failed step. Defaults to 0.25.
-        maxiter (int, optional): max number of line search iterations.
-    """
-    def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay=0.8, decay2=0.2, maxiter=10, ):
-        defaults = dict(pmin=pmin, pmax=pmax, pow=pow, maxiter=maxiter, decay=decay, decay2=decay2)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None:
-            raise RuntimeError("CD requires closure")
-
-        params = TensorList(var.params)
-        finfo = torch.finfo(params[0].dtype)
-        f_0 = var.get_loss(False)
-
-        # ------------------------------ determine index ----------------------------- #
-        idx, igrad = _icd_get_idx(self, params)
-
-        # -------------------------- find descent direction -------------------------- #
-        h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
-        h = float(h_vec.flat_get(idx))
-
-        params.flat_set_lambda_(idx, lambda x: x + h)
-        f_p = closure(False)
-
-        params.flat_set_lambda_(idx, lambda x: x - 2*h)
-        f_n = closure(False)
-        params.flat_set_lambda_(idx, lambda x: x + h)
-
-        # ---------------------------------- adapt h --------------------------------- #
-        if f_0 <= f_p and f_0 <= f_n:
-            h_vec.flat_set_lambda_(idx, lambda x: max(x/2, finfo.tiny * 2))
-        else:
-            # here eps, not tiny
-            if abs(f_0 - f_n) < finfo.eps or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
-                h_vec.flat_set_lambda_(idx, lambda x: min(x*2, finfo.max / 2))
-
-        # ------------------------------- update igrad ------------------------------- #
-        if f_0 < f_p and f_0 < f_n: alpha = 0
-        else: alpha = (f_p - f_n) / (2*h)
-
-        igrad.flat_set_(idx, alpha)
-
-        # -------------------------------- line search ------------------------------- #
-        x0 = params.clone()
-        def f(a):
-            params.sub_(igrad, alpha=a)
-            loss = closure(False)
-            params.copy_(x0)
-            return loss
-
-        a_prev = self.global_state.get('a_prev', 1)
-        a, f_a, niter = adaptive_tracking(f, a_prev, maxiter=self.defaults['maxiter'], f_0=f_0)
-        if (a is None) or (not math.isfinite(a)) or (not math.isfinite(f_a)):
-            a = 0
-
-        # -------------------------------- set a_prev -------------------------------- #
-        decay2 = self.defaults["decay2"]
-        decay = self.global_state.get("decay", self.defaults["decay"])
-
-        if abs(a) > finfo.tiny * 2:
-            assert f_a < f_0
-            self.global_state['a_prev'] = max(min(a, finfo.max / 2), finfo.tiny * 2)
-            self.global_state["decay"] = self.defaults["decay"]
-
-        # ---------------------------- soft reset on fail ---------------------------- #
-        else:
-            igrad *= decay
-            self.global_state["decay"] = decay*decay2
-            self.global_state['a_prev'] = a_prev / 2
-
-        # -------------------------------- set update -------------------------------- #
-        var.update = igrad * a
-        return var
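
For context, the ``alpha = (f_p - f_n) / (2*h)`` estimate in the removed code above is a plain central difference along one coordinate; a standalone illustration, independent of the torchzero API:

```python
# Central-difference coordinate derivative, as used by the removed CCD/CCDLS code.
def central_difference(f, x, i, h=1e-3):
    """Approximate df/dx_i at x, where x is a list of floats."""
    xp, xn = list(x), list(x)
    xp[i] += h
    xn[i] -= h
    return (f(xp) - f(xn)) / (2 * h)

# f(x) = x0**2 + 3*x1, so df/dx1 = 3
assert abs(central_difference(lambda v: v[0] ** 2 + 3 * v[1], [1.0, 2.0], 1) - 3.0) < 1e-6
```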

torchzero/utils/derivatives.py
@@ -5,13 +5,13 @@ import torch.autograd.forward_ad as fwAD
 
 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
 
-def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-    flat_input = torch.cat([i.reshape(-1) for i in output])
-    grad_ouputs = torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype)
+def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+    flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
+    grad_ouputs = torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype)
     jac = []
-    for i in range(flat_input.numel()):
+    for i in range(flat_outputs.numel()):
         jac.append(torch.autograd.grad(
-            flat_input,
+            flat_outputs,
             wrt,
             grad_ouputs[i],
             retain_graph=True,
@@ -22,12 +22,12 @@ def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], creat
     return [torch.stack(z) for z in zip(*jac)]
 
 
-def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-    flat_input = torch.cat([i.reshape(-1) for i in output])
+def _jacobian_batched(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+    flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
     return torch.autograd.grad(
-        flat_input,
+        flat_outputs,
         wrt,
-        torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype),
+        torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype),
         retain_graph=True,
         create_graph=create_graph,
         allow_unused=True,
@@ -51,13 +51,13 @@ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
     return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
 
 
-def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
+def jacobian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
     """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
     Returns a sequence of tensors with the length as `wrt`.
     Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.
 
     Args:
-        input (Sequence[torch.Tensor]): input sequence of tensors.
+        outputs (Sequence[torch.Tensor]): input sequence of tensors.
         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
         create_graph (bool, optional):
             pytorch option, if True, graph of the derivative will be constructed,
@@ -68,16 +68,16 @@ def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], cr
     Returns:
         sequence of tensors with the length as `wrt`.
     """
-    if batched: return _jacobian_batched(output, wrt, create_graph)
-    return _jacobian(output, wrt, create_graph)
+    if batched: return _jacobian_batched(outputs, wrt, create_graph)
+    return _jacobian(outputs, wrt, create_graph)
 
-def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+def jacobian_and_hessian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
     Calculating hessian requires calculating the jacobian. So this function is more efficient than
     calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
 
     Args:
-        input (Sequence[torch.Tensor]): input sequence of tensors.
+        outputs (Sequence[torch.Tensor]): input sequence of tensors.
         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
         create_graph (bool, optional):
             pytorch option, if True, graph of the derivative will be constructed,
@@ -87,7 +87,7 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
     Returns:
         tuple with jacobians sequence and hessians sequence.
     """
-    jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+    jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
     return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
 
 
@@ -96,13 +96,13 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
 # Note - I only tested this for cases where input is a scalar."""
 # return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
 
-def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+def jacobian_and_hessian_mat_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
     Calculating hessian requires calculating the jacobian. So this function is more efficient than
     calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
 
     Args:
-        input (Sequence[torch.Tensor]): input sequence of tensors.
+        outputs (Sequence[torch.Tensor]): input sequence of tensors.
         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
         create_graph (bool, optional):
             pytorch option, if True, graph of the derivative will be constructed,
@@ -112,7 +112,7 @@ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
     Returns:
         tuple with jacobians sequence and hessians sequence.
     """
-    jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+    jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
     H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
     return flatten_jacobian(jac), flatten_jacobian(H_list)
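
The ``jacobian_and_hessian_*`` helpers above obtain the Hessian as the Jacobian of the Jacobian. The identity can be reproduced with plain ``torch.autograd`` (a standalone reference sketch, not the torchzero API):

```python
import torch

def f(x):
    return (x ** 3).sum()

x = torch.randn(4, requires_grad=True)

# First derivative, kept differentiable so it can be differentiated again.
(g,) = torch.autograd.grad(f(x), x, create_graph=True)   # g = 3 * x**2

# Hessian = Jacobian of the gradient: differentiate each entry of g w.r.t. x.
rows = [torch.autograd.grad(g[i], x, retain_graph=True)[0] for i in range(x.numel())]
H = torch.stack(rows)                                     # shape (4, 4)

assert torch.allclose(H, torch.diag(6 * x).detach())      # analytic Hessian is diag(6*x)
```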
 

torchzero/utils/linalg/linear_operator.py
@@ -35,8 +35,8 @@ class LinearOperator(ABC):
         """solve with a norm bound on x"""
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")
 
-    def update(self, *args, **kwargs) -> None:
-        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
+    # def update(self, *args, **kwargs) -> None:
+    #     raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
 
     def add(self, x: torch.Tensor) -> "LinearOperator":
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add")
@@ -298,6 +298,7 @@ class AtA(LinearOperator):
 class AAT(LinearOperator):
     def __init__(self, A: torch.Tensor):
         self.A = A
+        self.device = self.A.device; self.dtype = self.A.dtype
 
     def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
     def rmatvec(self, x): return self.matvec(x)
@@ -327,3 +328,50 @@ class AAT(LinearOperator):
         n = self.A.size(1)
         return (n,n)
 
+
+class Sketched(LinearOperator):
+    """A projected by sketching matrix S, representing the operator S @ A_proj @ S.T.
+
+    Where A is (n, n) and S is (n, sketch_size).
+    """
+    def __init__(self, S: torch.Tensor, A_proj: torch.Tensor):
+        self.S = S
+        self.A_proj = A_proj
+        self.device = self.A_proj.device; self.dtype = self.A_proj.dtype
+
+
+    def matvec(self, x):
+        x_proj = self.S.T @ x
+        Ax_proj = self.A_proj @ x_proj
+        return self.S @ Ax_proj
+
+    def rmatvec(self, x):
+        x_proj = self.S.T @ x
+        ATx_proj = self.A_proj.mH @ x_proj
+        return self.S @ ATx_proj
+
+
+    def matmat(self, x): return Dense(torch.linalg.multi_dot([self.S, self.A_proj, self.S.T, x])) # pylint:disable=not-callable
+    def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.S, self.A_proj.mH, self.S.T, x])) # pylint:disable=not-callable
+
+
+    def is_dense(self): return False
+    def to_tensor(self): return self.S @ self.A_proj @ self.S.T
+    def transpose(self): return Sketched(self.S, self.A_proj.mH)
+
+    def add_diagonal(self, x):
+        """this doesn't correspond to adding diagonal to A, however it still works for LM etc."""
+        if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+        if isinstance(x, (int,float)): x = torch.full((self.A_proj.shape[0],), fill_value=x, device=self.A_proj.device, dtype=self.A_proj.dtype)
+        return Sketched(S=self.S, A_proj=self.A_proj + x.diag_embed())
+
+    def solve(self, b):
+        return self.S @ torch.linalg.lstsq(self.A_proj, self.S.T @ b).solution # pylint:disable=not-callable
+
+    def inv(self):
+        return Sketched(S=self.S, A_proj=torch.linalg.pinv(self.A_proj)) # pylint:disable=not-callable
+
+    def size(self):
+        n = self.S.size(0)
+        return (n,n)
+
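
A small sketch of what the new ``Sketched`` operator represents: with a sketching matrix ``S`` of shape ``(n, k)`` and a projected matrix ``A_proj`` of shape ``(k, k)``, its matvec is ``S @ A_proj @ S.T @ x``. This is a plain-PyTorch illustration rather than the torchzero class itself, and the orthonormal ``S`` is only an assumed use case:

```python
import torch

n, k = 6, 3
S = torch.linalg.qr(torch.randn(n, k)).Q   # (n, k) sketching matrix with orthonormal columns
A_proj = torch.randn(k, k)                 # projected (k, k) matrix

def matvec(x):
    # same factor order as Sketched.matvec above: project, apply A_proj, lift back
    return S @ (A_proj @ (S.T @ x))

x = torch.randn(n)
dense = S @ A_proj @ S.T                   # what Sketched.to_tensor() materializes
assert torch.allclose(matvec(x), dense @ x, atol=1e-5)
```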

torchzero/utils/optimizer.py
@@ -110,7 +110,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
     for i, param in enumerate(params):
         s = state[param]
         if key not in s:
-            if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+            if must_exist: raise KeyError(f"Key `{key}` doesn't exist in state with keys {tuple(s.keys())}")
             s[key] = _make_initial_state_value(param, init, i)
         values.append(s[key])
     return values
@@ -125,7 +125,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
         s = state[param]
         for k_i, key in enumerate(keys):
             if key not in s:
-                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                if must_exist: raise KeyError(f"Key `{key}` doesn't exist in state with keys {tuple(s.keys())}")
                 k_init = init[k_i] if isinstance(init, (list,tuple)) else init
                 s[key] = _make_initial_state_value(param, k_init, i)
             values[k_i].append(s[key])

torchzero/utils/python_tools.py
@@ -67,3 +67,4 @@ def safe_dict_update_(d1_:dict, d2:dict):
     inter = set(d1_.keys()).intersection(d2.keys())
     if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
     d1_.update(d2)
+

{torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.13
+Version: 0.3.15
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 Project-URL: Homepage, https://github.com/inikishev/torchzero