torchzero 0.3.14__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. tests/test_opts.py +4 -3
  2. torchzero/core/__init__.py +4 -1
  3. torchzero/core/chain.py +50 -0
  4. torchzero/core/functional.py +37 -0
  5. torchzero/core/modular.py +237 -0
  6. torchzero/core/module.py +8 -599
  7. torchzero/core/reformulation.py +3 -1
  8. torchzero/core/transform.py +7 -5
  9. torchzero/core/var.py +376 -0
  10. torchzero/modules/__init__.py +0 -1
  11. torchzero/modules/adaptive/adahessian.py +2 -2
  12. torchzero/modules/adaptive/esgd.py +2 -2
  13. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  14. torchzero/modules/adaptive/sophia_h.py +2 -2
  15. torchzero/modules/experimental/__init__.py +1 -0
  16. torchzero/modules/experimental/newtonnewton.py +5 -5
  17. torchzero/modules/experimental/spsa1.py +2 -2
  18. torchzero/modules/functional.py +7 -0
  19. torchzero/modules/line_search/__init__.py +1 -1
  20. torchzero/modules/line_search/_polyinterp.py +3 -1
  21. torchzero/modules/line_search/adaptive.py +3 -3
  22. torchzero/modules/line_search/backtracking.py +1 -1
  23. torchzero/modules/line_search/interpolation.py +160 -0
  24. torchzero/modules/line_search/line_search.py +11 -20
  25. torchzero/modules/line_search/strong_wolfe.py +3 -3
  26. torchzero/modules/misc/misc.py +2 -2
  27. torchzero/modules/misc/multistep.py +13 -13
  28. torchzero/modules/quasi_newton/__init__.py +2 -0
  29. torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  30. torchzero/modules/quasi_newton/sg2.py +292 -0
  31. torchzero/modules/second_order/__init__.py +6 -3
  32. torchzero/modules/second_order/ifn.py +89 -0
  33. torchzero/modules/second_order/inm.py +105 -0
  34. torchzero/modules/second_order/newton.py +103 -193
  35. torchzero/modules/second_order/nystrom.py +1 -1
  36. torchzero/modules/second_order/rsn.py +227 -0
  37. torchzero/modules/wrappers/optim_wrapper.py +49 -42
  38. torchzero/utils/derivatives.py +19 -19
  39. torchzero/utils/linalg/linear_operator.py +50 -2
  40. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
  41. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/RECORD +44 -36
  42. torchzero/modules/higher_order/__init__.py +0 -1
  43. /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
  44. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
  45. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,227 @@
+ import math
+ from collections import deque
+ from collections.abc import Callable
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import Distributions, TensorList, vec_to_tensors
+ from ...utils.linalg.linear_operator import Sketched
+ from .newton import _newton_step
+
+ def _qr_orthonormalize(A:torch.Tensor):
+     m,n = A.shape
+     if m < n:
+         q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
+         return q.T
+     else:
+         q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
+         return q
+
+ def _orthonormal_sketch(m, n, dtype, device, generator):
+     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
+
+ def _gaussian_sketch(m, n, dtype, device, generator):
+     return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
+
+ class RSN(Module):
+     """Randomized Subspace Newton. Performs a Newton step in a random subspace.
+
+     Args:
+         sketch_size (int):
+             size of the random sketch. This many hessian-vector products will need to be evaluated each step.
+         sketch_type (str, optional):
+             - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
+             - "gaussian" - random gaussian (not orthonormal) basis.
+             - "common_directions" - uses the history of steepest descent directions as the basis [2]. It is orthonormalized on-line using Gram-Schmidt.
+             - "mixed" - random orthonormal basis but with three directions set to gradient, slow EMA and fast EMA (default).
+         damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
+         hvp_method (str, optional):
+             How to compute the hessian-matrix product:
+             - "batched" - uses batched autograd
+             - "autograd" - uses unbatched autograd
+             - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
+             - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
+
+             Defaults to "batched".
+         h (float, optional): finite difference step size. Defaults to 1e-2.
+         use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to True.
+         update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+         H_tfm (Callable | None, optional):
+             optional hessian transform, takes in two arguments - `(hessian, gradient)`.
+
+             must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+             which must be True if transform inverted the hessian and False otherwise.
+
+             Or it returns a single tensor which is used as the update.
+
+             Defaults to None.
+         eigval_fn (Callable | None, optional):
+             optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
+             If this is specified, eigendecomposition will be used to invert the hessian.
+         seed (int | None, optional): seed for random generator. Defaults to None.
+         inner (Chainable | None, optional): preconditions output of this module. Defaults to None.
+
+     ### Examples
+
+     RSN with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.RSN(),
+         tz.m.Backtracking()
+     )
+     ```
+
+     RSN with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LevenbergMarquardt(tz.m.RSN()),
+     )
+     ```
+
+
+     References:
+         1. [Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).](https://arxiv.org/abs/1905.10874)
+         2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
+     """
+
+     def __init__(
+         self,
+         sketch_size: int,
+         sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
+         damping:float=0,
+         hvp_method: Literal["batched", "autograd", "forward", "central"] = "batched",
+         h: float = 1e-2,
+         use_lstsq: bool = True,
+         update_freq: int = 1,
+         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+         seed: int | None = None,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(sketch_size=sketch_size, sketch_type=sketch_type,seed=seed,hvp_method=hvp_method, h=h, damping=damping, use_lstsq=use_lstsq, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child("inner", inner)
+
+     @torch.no_grad
+     def update(self, var):
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         if step % self.defaults['update_freq'] == 0:
+
+             closure = var.closure
+             if closure is None:
+                 raise RuntimeError("RSN requires closure")
+             params = var.params
+             generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+             ndim = sum(p.numel() for p in params)
+
+             device=params[0].device
+             dtype=params[0].dtype
+
+             # sample sketch matrix S: (ndim, sketch_size)
+             sketch_size = min(self.defaults["sketch_size"], ndim)
+             sketch_type = self.defaults["sketch_type"]
+             hvp_method = self.defaults["hvp_method"]
+
+             if sketch_type in ('normal', 'gaussian'):
+                 S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
+             elif sketch_type == 'orthonormal':
+                 S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
+             elif sketch_type == 'common_directions':
+                 # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
+                 g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
+                 g = torch.cat([t.ravel() for t in g_list])
+
+                 # initialize directions deque
+                 if "directions" not in self.global_state:
+
+                     g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+                     if g_norm < torch.finfo(g.dtype).tiny * 2:
+                         g = torch.randn_like(g)
+                         g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+
+                     self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
+                     S = self.global_state["directions"][0].unsqueeze(1)
+
+                 # add new steepest descent direction orthonormal to existing columns
+                 else:
+                     S = torch.stack(tuple(self.global_state["directions"]), dim=1)
+                     p = g - S @ (S.T @ g)
+                     p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
+                     if p_norm > torch.finfo(p.dtype).tiny * 2:
+                         p = p / p_norm
+                         self.global_state["directions"].append(p)
+                         S = torch.cat([S, p.unsqueeze(1)], dim=1)
+
+             elif sketch_type == "mixed":
+                 g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
+                 g = torch.cat([t.ravel() for t in g_list])
+
+                 if "slow_ema" not in self.global_state:
+                     self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
+                     self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
+
+                 slow_ema = self.global_state["slow_ema"]
+                 fast_ema = self.global_state["fast_ema"]
+                 slow_ema.lerp_(g, 0.001)
+                 fast_ema.lerp_(g, 0.1)
+
+                 S = torch.stack([g, slow_ema, fast_ema], dim=1)
+                 if sketch_size > 3:
+                     S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
+                     S = torch.cat([S, S_random], dim=1)
+
+                 S = _qr_orthonormalize(S)
+
+             else:
+                 raise ValueError(f'Unknown sketch_type {sketch_type}')
+
+             # form sketched hessian
+             HS, _ = var.hessian_matrix_product(S, at_x0=True, rgrad=None, hvp_method=self.defaults["hvp_method"], normalize=True, retain_graph=False, h=self.defaults["h"])
+             H_sketched = S.T @ HS
+
+             self.global_state["H_sketched"] = H_sketched
+             self.global_state["S"] = S
+
+     def apply(self, var):
+         S: torch.Tensor = self.global_state["S"]
+         d_proj = _newton_step(
+             var=var,
+             H=self.global_state["H_sketched"],
+             damping=self.defaults["damping"],
+             inner=self.children.get("inner", None),
+             H_tfm=self.defaults["H_tfm"],
+             eigval_fn=self.defaults["eigval_fn"],
+             use_lstsq=self.defaults["use_lstsq"],
+             g_proj = lambda g: S.T @ g
+         )
+         d = S @ d_proj
+         var.update = vec_to_tensors(d, var.params)
+
+         return var
+
+     def get_H(self, var=...):
+         eigval_fn = self.defaults["eigval_fn"]
+         H_sketched: torch.Tensor = self.global_state["H_sketched"]
+         S: torch.Tensor = self.global_state["S"]
+
+         if eigval_fn is not None:
+             try:
+                 L, Q = torch.linalg.eigh(H_sketched) # pylint:disable=not-callable
+                 L: torch.Tensor = eigval_fn(L)
+                 H_sketched = Q @ L.diag_embed() @ Q.mH
+
+             except torch.linalg.LinAlgError:
+                 pass
+
+         return Sketched(S, H_sketched)
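The new module is the classic randomized subspace Newton update: sample a sketch S with sketch_size columns, form the projected Hessian Sᵀ H S from that many Hessian-vector products, solve the small system against the projected gradient, and lift the step back with S. A minimal self-contained sketch of that update in plain PyTorch, independent of torchzero's Module/var machinery (rsn_step, the quadratic test problem and the step loop below are illustrative, not part of the package):

```python
import torch

def rsn_step(loss_fn, x, sketch_size, damping=1e-8):
    """One randomized subspace Newton step on a flat parameter vector x."""
    n = x.numel()
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(loss_fn(x), x, create_graph=True)

    # orthonormal sketch S: (n, sketch_size)
    S, _ = torch.linalg.qr(torch.randn(n, sketch_size, dtype=x.dtype))

    # sketch_size Hessian-vector products give H @ S, then project: H_proj = S.T @ H @ S
    HS = torch.stack([torch.autograd.grad(g, x, S[:, i], retain_graph=True)[0]
                      for i in range(sketch_size)], dim=1)
    H_proj = S.T @ HS + damping * torch.eye(sketch_size, dtype=x.dtype)

    # solve in the subspace and lift the step back to R^n
    d_proj = torch.linalg.solve(H_proj, S.T @ g)
    return x.detach() - S @ d_proj

# usage on a small positive definite quadratic
A = torch.randn(50, 50); A = A @ A.T + torch.eye(50)
b = torch.randn(50)
f = lambda z: 0.5 * z @ (A @ z) - b @ z
x = torch.zeros(50)
for _ in range(25):
    x = rsn_step(f, x, sketch_size=10)
```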
@@ -10,34 +10,48 @@ class Wrap(Module):
      """
      Wraps a pytorch optimizer to use it as a module.
 
-     .. note::
-         Custom param groups are supported only by `set_param_groups`, settings passed to Modular will be ignored.
+     Note:
+         Custom param groups are supported only by ``set_param_groups``, settings passed to Modular will be applied to all parameters.
 
      Args:
          opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
-             function that takes in parameters and returns the optimizer, for example :code:`torch.optim.Adam`
-             or :code:`lambda parameters: torch.optim.Adam(parameters, lr=1e-3)`
+             function that takes in parameters and returns the optimizer, for example ``torch.optim.Adam``
+             or ``lambda parameters: torch.optim.Adam(parameters, lr=1e-3)``
          *args:
          **kwargs:
-             Extra args to be passed to opt_fn. The function is called as :code:`opt_fn(parameters, *args, **kwargs)`.
+             Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
+         use_param_groups:
+             Whether to pass settings passed to Modular to the wrapped optimizer.
 
-     Example:
-         wrapping pytorch_optimizer.StableAdamW
+             Note that settings to the first parameter are used for all parameters,
+             so if you specified per-parameter settings, they will be ignored.
 
-     .. code-block:: py
+     ### Example:
+     wrapping pytorch_optimizer.StableAdamW
 
-         from pytorch_optimizer import StableAdamW
-         opt = tz.Modular(
-             model.parameters(),
-             tz.m.Wrap(StableAdamW, lr=1),
-             tz.m.Cautious(),
-             tz.m.LR(1e-2)
-         )
+     ```python
 
+     from pytorch_optimizer import StableAdamW
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.Wrap(StableAdamW, lr=1),
+         tz.m.Cautious(),
+         tz.m.LR(1e-2)
+     )
+     ```
 
      """
-     def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
-         super().__init__()
+
+     def __init__(
+         self,
+         opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer,
+         *args,
+         use_param_groups: bool = True,
+         **kwargs,
+     ):
+         defaults = dict(use_param_groups=use_param_groups)
+         super().__init__(defaults=defaults)
+
          self._opt_fn = opt_fn
          self._opt_args = args
          self._opt_kwargs = kwargs
@@ -48,7 +62,7 @@ class Wrap(Module):
              self.optimizer = self._opt_fn
 
      def set_param_groups(self, param_groups):
-         self._custom_param_groups = param_groups
+         self._custom_param_groups = _make_param_groups(param_groups, differentiable=False)
          return super().set_param_groups(param_groups)
 
      @torch.no_grad
@@ -61,37 +75,29 @@ class Wrap(Module):
              param_groups = params if self._custom_param_groups is None else self._custom_param_groups
              self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)
 
+         # set optimizer per-parameter settings
+         if self.defaults["use_param_groups"] and var.modular is not None:
+             for group in self.optimizer.param_groups:
+                 first_param = group['params'][0]
+                 setting = self.settings[first_param]
+
+                 # settings passed in `set_param_groups` are the highest priority
+                 # schedulers will override defaults but not settings passed in `set_param_groups`
+                 # this is consistent with how Modular does it.
+                 if self._custom_param_groups is not None:
+                     setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
+
+                 group.update(setting)
+
          # set grad to update
          orig_grad = [p.grad for p in params]
          for p, u in zip(params, var.get_update()):
              p.grad = u
 
-         # if this module is last, can step with _opt directly
-         # direct step can't be applied if next module is LR but _opt doesn't support lr,
-         # and if there are multiple different per-parameter lrs (would be annoying to support)
-         if var.is_last and (
-             (var.last_module_lrs is None)
-             or
-             (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
-         ):
-             lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]
-
-             # update optimizer lr with desired lr
-             if lr != 1:
-                 self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
-                 for g in self.optimizer.param_groups:
-                     g['__original_lr__'] = g['lr']
-                     g['lr'] = g['lr'] * lr
-
-             # step
+         # if this is last module, simply use optimizer to update parameters
+         if var.modular is not None and self is var.modular.modules[-1]:
              self.optimizer.step()
 
-             # restore original lr
-             if lr != 1:
-                 self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
-                 for g in self.optimizer.param_groups:
-                     g['lr'] = g.pop('__original_lr__')
-
              # restore grad
              for p, g in zip(params, orig_grad):
                  p.grad = g
@@ -100,6 +106,7 @@ class Wrap(Module):
              return var
 
          # this is not the last module, meaning update is difference in parameters
+         # and passed to next module
          params_before_step = [p.clone() for p in params]
          self.optimizer.step() # step and update params
          for p, g in zip(params, orig_grad):
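On the path above where Wrap is not the last module, the wrapped optimizer's step is turned into an update for the rest of the chain: the incoming update is placed into .grad, parameters are cloned, the optimizer steps, and the parameter difference becomes the update handed downstream. A rough standalone sketch of that pattern in plain PyTorch (wrapped_update, the restore step and the sign convention are assumptions for illustration, not the actual Wrap code):

```python
import torch

def wrapped_update(params, update, opt):
    """Run one step of a wrapped torch.optim optimizer and return the parameter
    delta it would apply, leaving parameters and gradients as they were."""
    orig_grad = [p.grad for p in params]
    before = [p.detach().clone() for p in params]

    # feed the incoming update to the optimizer through .grad
    for p, u in zip(params, update):
        p.grad = u
    opt.step()

    # assumed sign convention: delta = before - after, i.e. a step to be subtracted
    delta = [b - p.detach() for b, p in zip(before, params)]

    # restore parameters and gradients so only the delta is reported
    with torch.no_grad():
        for p, b in zip(params, before):
            p.copy_(b)
    for p, g in zip(params, orig_grad):
        p.grad = g
    return delta

# usage: extract Adam's step as an update for further processing
w = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.Adam([w], lr=1e-2)
delta = wrapped_update([w], [torch.randn(3)], opt)
```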
@@ -5,13 +5,13 @@ import torch.autograd.forward_ad as fwAD
 
  from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
 
- def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-     flat_input = torch.cat([i.reshape(-1) for i in output])
-     grad_ouputs = torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype)
+ def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+     flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
+     grad_ouputs = torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype)
      jac = []
-     for i in range(flat_input.numel()):
+     for i in range(flat_outputs.numel()):
          jac.append(torch.autograd.grad(
-             flat_input,
+             flat_outputs,
              wrt,
              grad_ouputs[i],
              retain_graph=True,
@@ -22,12 +22,12 @@ def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], creat
      return [torch.stack(z) for z in zip(*jac)]
 
 
- def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-     flat_input = torch.cat([i.reshape(-1) for i in output])
+ def _jacobian_batched(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+     flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
      return torch.autograd.grad(
-         flat_input,
+         flat_outputs,
          wrt,
-         torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype),
+         torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype),
          retain_graph=True,
          create_graph=create_graph,
          allow_unused=True,
@@ -51,13 +51,13 @@ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
      return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
 
 
- def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
+ def jacobian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
      """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
      Returns a sequence of tensors with the length as `wrt`.
      Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.
 
      Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
+         outputs (Sequence[torch.Tensor]): input sequence of tensors.
          wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
          create_graph (bool, optional):
              pytorch option, if True, graph of the derivative will be constructed,
@@ -68,16 +68,16 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
      Returns:
          sequence of tensors with the length as `wrt`.
      """
-     if batched: return _jacobian_batched(output, wrt, create_graph)
-     return _jacobian(output, wrt, create_graph)
+     if batched: return _jacobian_batched(outputs, wrt, create_graph)
+     return _jacobian(outputs, wrt, create_graph)
 
- def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+ def jacobian_and_hessian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
      """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
      Calculating hessian requires calculating the jacobian. So this function is more efficient than
      calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
 
      Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
+         outputs (Sequence[torch.Tensor]): input sequence of tensors.
          wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
          create_graph (bool, optional):
              pytorch option, if True, graph of the derivative will be constructed,
@@ -87,7 +87,7 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
      Returns:
          tuple with jacobians sequence and hessians sequence.
      """
-     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+     jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
      return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
 
 
@@ -96,13 +96,13 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
  # Note - I only tested this for cases where input is a scalar."""
  # return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
 
- def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+ def jacobian_and_hessian_mat_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
      """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
      Calculating hessian requires calculating the jacobian. So this function is more efficient than
      calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
 
      Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
+         outputs (Sequence[torch.Tensor]): input sequence of tensors.
          wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
          create_graph (bool, optional):
              pytorch option, if True, graph of the derivative will be constructed,
@@ -112,7 +112,7 @@ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
      Returns:
          tuple with jacobians sequence and hessians sequence.
      """
-     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+     jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
      H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
      return flatten_jacobian(jac), flatten_jacobian(H_list)
 
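The unbatched _jacobian path builds the Jacobian one row at a time by back-propagating one-hot grad_outputs vectors, while the batched path pushes the whole identity matrix through a single torch.autograd.grad call. A small standalone illustration of the row-by-row pattern (plain PyTorch; jacobian_rows is an illustrative name, not the package's API):

```python
import torch

def jacobian_rows(output_vec, wrt):
    """Jacobian of a flat output vector w.r.t. a list of tensors, built one row
    at a time with one-hot grad_outputs (the same pattern _jacobian uses)."""
    m = output_vec.numel()
    eye = torch.eye(m, dtype=output_vec.dtype, device=output_vec.device)
    rows = [torch.autograd.grad(output_vec, wrt, eye[i], retain_graph=True)
            for i in range(m)]
    # one tensor of shape (m, *w.shape) per tensor in `wrt`
    return [torch.stack(cols) for cols in zip(*rows)]

# usage: Jacobian of y = W @ x with respect to W and x
W = torch.randn(3, 4, requires_grad=True)
x = torch.randn(4, requires_grad=True)
y = W @ x
JW, Jx = jacobian_rows(y, [W, x])   # shapes (3, 3, 4) and (3, 4)
assert torch.allclose(Jx, W)        # dy_i/dx = W[i, :]
```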
@@ -35,8 +35,8 @@ class LinearOperator(ABC):
          """solve with a norm bound on x"""
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")
 
-     def update(self, *args, **kwargs) -> None:
-         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
+     # def update(self, *args, **kwargs) -> None:
+     #     raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
 
      def add(self, x: torch.Tensor) -> "LinearOperator":
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add")
@@ -298,6 +298,7 @@ class AtA(LinearOperator):
  class AAT(LinearOperator):
      def __init__(self, A: torch.Tensor):
          self.A = A
+         self.device = self.A.device; self.dtype = self.A.dtype
 
      def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
      def rmatvec(self, x): return self.matvec(x)
@@ -327,3 +328,50 @@ class AAT(LinearOperator):
          n = self.A.size(1)
          return (n,n)
 
+
+ class Sketched(LinearOperator):
+     """A projected by sketching matrix S, representing the operator S @ A_proj @ S.T.
+
+     Where A is (n, n) and S is (n, sketch_size).
+     """
+     def __init__(self, S: torch.Tensor, A_proj: torch.Tensor):
+         self.S = S
+         self.A_proj = A_proj
+         self.device = self.A_proj.device; self.dtype = self.A_proj.dtype
+
+
+     def matvec(self, x):
+         x_proj = self.S.T @ x
+         Ax_proj = self.A_proj @ x_proj
+         return self.S @ Ax_proj
+
+     def rmatvec(self, x):
+         x_proj = self.S.T @ x
+         ATx_proj = self.A_proj.mH @ x_proj
+         return self.S @ ATx_proj
+
+
+     def matmat(self, x): return Dense(torch.linalg.multi_dot([self.S, self.A_proj, self.S.T, x])) # pylint:disable=not-callable
+     def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.S, self.A_proj.mH, self.S.T, x])) # pylint:disable=not-callable
+
+
+     def is_dense(self): return False
+     def to_tensor(self): return self.S @ self.A_proj @ self.S.T
+     def transpose(self): return Sketched(self.S, self.A_proj.mH)
+
+     def add_diagonal(self, x):
+         """this doesn't correspond to adding diagonal to A, however it still works for LM etc."""
+         if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+         if isinstance(x, (int,float)): x = torch.full((self.A_proj.shape[0],), fill_value=x, device=self.A_proj.device, dtype=self.A_proj.dtype)
+         return Sketched(S=self.S, A_proj=self.A_proj + x.diag_embed())
+
+     def solve(self, b):
+         return self.S @ torch.linalg.lstsq(self.A_proj, self.S.T @ b).solution # pylint:disable=not-callable
+
+     def inv(self):
+         return Sketched(S=self.S, A_proj=torch.linalg.pinv(self.A_proj)) # pylint:disable=not-callable
+
+     def size(self):
+         n = self.S.size(0)
+         return (n,n)
+
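Sketched represents the n x n matrix S @ A_proj @ S.T without materializing it: matvec projects into the subspace, applies A_proj and lifts back, and solve only inverts the action inside range(S). A quick numerical check of that algebra in standalone PyTorch (the helper functions and the well-conditioned test matrix below are illustrative, mirroring the formulas in the hunk rather than importing the class):

```python
import torch

n, k = 100, 10
S, _ = torch.linalg.qr(torch.randn(n, k))   # sketch with orthonormal columns
A_proj = torch.randn(k, k)
A_proj = A_proj @ A_proj.T + torch.eye(k)   # well-conditioned projected matrix

def matvec(x):                              # S @ (A_proj @ (S.T @ x))
    return S @ (A_proj @ (S.T @ x))

def solve(b):                               # S @ lstsq(A_proj, S.T @ b)
    rhs = (S.T @ b).unsqueeze(-1)
    return S @ torch.linalg.lstsq(A_proj, rhs).solution.squeeze(-1)

x = torch.randn(n)
dense = S @ A_proj @ S.T                    # what to_tensor() returns
assert torch.allclose(matvec(x), dense @ x, atol=1e-4)

# solve() inverts the operator's action on vectors in range(S)
b = S @ torch.randn(k)
assert torch.allclose(matvec(solve(b)), b, atol=1e-4)
```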
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.14
+ Version: 0.3.15
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  Project-URL: Homepage, https://github.com/inikishev/torchzero