torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0

torchzero/optim/wrappers/nlopt.py
@@ -75,8 +75,6 @@ class NLOptWrapper(Optimizer):
  so usually you would want to perform a single step, although performing multiple steps will refine the
  solution.
 
- Some algorithms are buggy with numpy>=2.
-
  Args:
  params: iterable of parameters to optimize or dicts defining parameter groups.
  algorithm (int | _ALGOS_LITERAL): optimization algorithm from https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/

torchzero/optim/wrappers/optuna.py
@@ -6,7 +6,7 @@ import torch
 
  import optuna
 
- from ...utils import Optimizer
+ from ...utils import Optimizer, totensor, tofloat
 
  def silence_optuna():
  optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -65,6 +65,6 @@ class OptunaSampler(Optimizer):
  params.from_vec_(vec)
 
  loss = closure()
- with torch.enable_grad(): self.study.tell(trial, loss)
+ with torch.enable_grad(): self.study.tell(trial, tofloat(torch.nan_to_num(totensor(loss), 1e32)))
 
  return loss
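
The `study.tell` change above guards against non-finite losses: Optuna rejects NaN objective values, so the loss is first converted to a tensor, NaN is replaced with a large finite penalty, and the result is reported as a plain float. A minimal sketch of that behaviour, using hypothetical stand-ins for torchzero's `totensor`/`tofloat` helpers:

```python
import torch

def totensor(x):   # hypothetical stand-in for torchzero.utils.totensor
    return x if isinstance(x, torch.Tensor) else torch.tensor(x)

def tofloat(x):    # hypothetical stand-in for torchzero.utils.tofloat
    return float(x.detach().cpu().item()) if isinstance(x, torch.Tensor) else float(x)

def loss_for_optuna(loss):
    # replace NaN with a large finite penalty so study.tell() accepts it
    return tofloat(torch.nan_to_num(totensor(loss), 1e32))

print(loss_for_optuna(torch.tensor(float("nan"))))  # 1e+32
print(loss_for_optuna(torch.tensor(0.25)))          # 0.25
```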

torchzero/optim/wrappers/scipy.py
@@ -4,12 +4,17 @@ from functools import partial
  from typing import Any, Literal
 
  import numpy as np
- import scipy.optimize
  import torch
 
+ import scipy.optimize
+
  from ...utils import Optimizer, TensorList
- from ...utils.derivatives import jacobian_and_hessian_mat_wrt, jacobian_wrt
- from ...modules.second_order.newton import tikhonov_
+ from ...utils.derivatives import (
+ flatten_jacobian,
+ jacobian_and_hessian_mat_wrt,
+ jacobian_wrt,
+ )
+
 
  def _ensure_float(x) -> float:
  if isinstance(x, torch.Tensor): return x.detach().cpu().item()
@@ -21,14 +26,6 @@ def _ensure_numpy(x):
  if isinstance(x, np.ndarray): return x
  return np.array(x)
 
- def matrix_clamp(H: torch.Tensor, reg: float):
- try:
- eigvals, eigvecs = torch.linalg.eigh(H) # pylint:disable=not-callable
- eigvals.clamp_(min=reg)
- return eigvecs @ torch.diag(eigvals) @ eigvecs.mH
- except Exception:
- return H
-
  Closure = Callable[[bool], Any]
 
  class ScipyMinimize(Optimizer):
@@ -76,8 +73,6 @@ class ScipyMinimize(Optimizer):
  options = None,
  jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
  hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
- tikhonov: float | None = 0,
- min_eigval: float | None = None,
  ):
  defaults = dict(lb=lb, ub=ub)
  super().__init__(params, defaults)
@@ -85,12 +80,10 @@ class ScipyMinimize(Optimizer):
  self.constraints = constraints
  self.tol = tol
  self.callback = callback
- self.min_eigval = min_eigval
  self.options = options
 
  self.jac = jac
  self.hess = hess
- self.tikhonov: float | None = tikhonov
 
  self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
  'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
@@ -111,9 +104,7 @@ class ScipyMinimize(Optimizer):
  with torch.enable_grad():
  value = closure(False)
  _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
- if self.tikhonov is not None: H = tikhonov_(H, self.tikhonov)
- if self.min_eigval is not None: H = matrix_clamp(H, self.min_eigval)
- return H.detach().cpu().numpy()
+ return H.numpy(force=True)
 
  def _objective(self, x: np.ndarray, params: TensorList, closure):
  # set params to x
@@ -122,7 +113,10 @@ class ScipyMinimize(Optimizer):
  # return value and maybe gradients
  if self.use_jac_autograd:
  with torch.enable_grad(): value = _ensure_float(closure())
- return value, params.ensure_grad_().grad.to_vec().detach().cpu().numpy()
+ grad = params.ensure_grad_().grad.to_vec().numpy(force=True)
+ # slsqp requires float64
+ if self.method.lower() == 'slsqp': grad = grad.astype(np.float64)
+ return value, grad
  return _ensure_float(closure(False))
 
  @torch.no_grad
@@ -135,7 +129,7 @@ class ScipyMinimize(Optimizer):
  else: hess = None
  else: hess = self.hess
 
- x0 = params.to_vec().detach().cpu().numpy()
+ x0 = params.to_vec().numpy(force=True)
 
  # make bounds
  lb, ub = self.group_vals('lb', 'ub', cls=list)
@@ -167,7 +161,7 @@ class ScipyMinimize(Optimizer):
 
 
  class ScipyRootOptimization(Optimizer):
- """Optimization via using scipy.root on gradients, mainly for experimenting!
+ """Optimization via using scipy.optimize.root on gradients, mainly for experimenting!
 
  Args:
  params: iterable of parameters to optimize or dicts defining parameter groups.
@@ -248,6 +242,72 @@ class ScipyRootOptimization(Optimizer):
  return res.fun
 
 
+ class ScipyLeastSquaresOptimization(Optimizer):
+ """Optimization via using scipy.optimize.least_squares on gradients, mainly for experimenting!
+
+ Args:
+ params: iterable of parameters to optimize or dicts defining parameter groups.
+ method (str | None, optional): _description_. Defaults to None.
+ tol (float | None, optional): _description_. Defaults to None.
+ callback (_type_, optional): _description_. Defaults to None.
+ options (_type_, optional): _description_. Defaults to None.
+ jac (T.Literal['2, optional): _description_. Defaults to 'autograd'.
+ """
+ def __init__(
+ self,
+ params,
+ method='trf',
+ jac='autograd',
+ bounds=(-np.inf, np.inf),
+ ftol=1e-8, xtol=1e-8, gtol=1e-8, x_scale=1.0, loss='linear',
+ f_scale=1.0, diff_step=None, tr_solver=None, tr_options=None,
+ jac_sparsity=None, max_nfev=None, verbose=0
+ ):
+ super().__init__(params, {})
+ kwargs = locals().copy()
+ del kwargs['self'], kwargs['params'], kwargs['__class__'], kwargs['jac']
+ self._kwargs = kwargs
+
+ self.jac = jac
+
+
+ def _objective(self, x: np.ndarray, params: TensorList, closure):
+ # set params to x
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+ # return the gradients
+ with torch.enable_grad(): self.value = closure()
+ jac = params.ensure_grad_().grad.to_vec()
+ return jac.numpy(force=True)
+
+ def _hess(self, x: np.ndarray, params: TensorList, closure):
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ with torch.enable_grad():
+ value = closure(False)
+ _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
+ return H.numpy(force=True)
+
+ @torch.no_grad
+ def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+ params = self.get_params()
+
+ x0 = params.to_vec().detach().cpu().numpy()
+
+ if self.jac == 'autograd': jac = partial(self._hess, params = params, closure = closure)
+ else: jac = self.jac
+
+ res = scipy.optimize.least_squares(
+ partial(self._objective, params = params, closure = closure),
+ x0 = x0,
+ jac=jac, # type:ignore
+ **self._kwargs
+ )
+
+ params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return res.fun
+
+
+
 
  class ScipyDE(Optimizer):
  """Use scipy.minimize.differential_evolution as pytorch optimizer. Note that this performs full minimization on each step,
@@ -510,4 +570,3 @@ class ScipyBrute(Optimizer):
  **self._kwargs
  )
  params.from_vec_(torch.from_numpy(x0).to(device = params[0].device, dtype=params[0].dtype, copy=False))
- return None
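
The new `ScipyLeastSquaresOptimization` wrapper feeds the parameter gradient to `scipy.optimize.least_squares` as the residual vector and, with `jac='autograd'`, supplies the autograd Hessian as that residual's Jacobian. A hedged usage sketch; the import path is taken from the file list above, and the closure signature (`closure(backward=True)`) is an assumption based on how the wrapper calls `closure()` and `closure(False)`:

```python
import torch, torch.nn as nn, torch.nn.functional as F
from torchzero.optim.wrappers.scipy import ScipyLeastSquaresOptimization  # path assumed from the file list

model = nn.Linear(4, 2)
X, y = torch.randn(64, 4), torch.randn(64, 2)
opt = ScipyLeastSquaresOptimization(model.parameters(), method='trf', max_nfev=100)

def closure(backward=True):
    # assumed convention: compute the loss and, when requested, populate .grad
    opt.zero_grad()
    loss = F.mse_loss(model(X), y)
    if backward:
        loss.backward()
    return loss

opt.step(closure)  # one full least_squares solve on the gradient "residuals"
```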

torchzero/utils/__init__.py
@@ -1,5 +1,11 @@
  from . import tensorlist as tl
- from .compile import _optional_compiler, benchmark_compile_cpu, benchmark_compile_cuda, set_compilation, enable_compilation
+ from .compile import (
+ _optional_compiler,
+ benchmark_compile_cpu,
+ benchmark_compile_cuda,
+ enable_compilation,
+ set_compilation,
+ )
  from .numberlist import NumberList
  from .optimizer import (
  Init,
@@ -18,6 +24,36 @@ from .params import (
  _copy_param_groups,
  _make_param_groups,
  )
- from .python_tools import flatten, generic_eq, generic_ne, reduce_dim, unpack_dicts
- from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like, generic_finfo_eps
- from .torch_tools import tofloat, tolist, tonumpy, totensor, vec_to_tensors, vec_to_tensors_, set_storage_
+ from .python_tools import (
+ flatten,
+ generic_eq,
+ generic_ne,
+ reduce_dim,
+ safe_dict_update_,
+ unpack_dicts,
+ )
+ from .tensorlist import (
+ Distributions,
+ Metrics,
+ TensorList,
+ as_tensorlist,
+ generic_clamp,
+ generic_finfo,
+ generic_finfo_eps,
+ generic_finfo_tiny,
+ generic_max,
+ generic_numel,
+ generic_randn_like,
+ generic_sum,
+ generic_vector_norm,
+ generic_zeros_like,
+ )
+ from .torch_tools import (
+ set_storage_,
+ tofloat,
+ tolist,
+ tonumpy,
+ totensor,
+ vec_to_tensors,
+ vec_to_tensors_,
+ )

torchzero/utils/compile.py
@@ -38,7 +38,7 @@ class _MaybeCompiledFunc:
  _optional_compiler = _OptionalCompiler()
  """this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""
 
- def set_compilation(enable: bool):
+ def set_compilation(enable: bool=True):
  """`enable` is False by default. When True, certain functions will be compiled, which may not work on some systems like Windows, but it usually improves performance."""
  _optional_compiler.enable = enable
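
With the new default argument, compilation can be opted into with a bare call. A small sketch (`set_compilation` is re-exported from `torchzero.utils`, as the utils/__init__.py hunk above shows):

```python
from torchzero.utils import set_compilation

set_compilation()       # now equivalent to set_compilation(True)
set_compilation(False)  # restore the default (no torch.compile)
```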
 

torchzero/utils/derivatives.py
@@ -2,7 +2,6 @@ from collections.abc import Iterable, Sequence
 
  import torch
  import torch.autograd.forward_ad as fwAD
- from typing import Literal
 
  from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
 
@@ -35,10 +34,27 @@ def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor
  is_grads_batched=True,
  )
 
+ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
+ """Converts the output of jacobian_wrt (a list of tensors) into a single 2D matrix.
+
+ Args:
+ jacs (Sequence[torch.Tensor]):
+ output from jacobian_wrt where ach tensor has the shape `(*output.shape, *wrt[i].shape)`.
+
+ Returns:
+ torch.Tensor: has the shape `(output.ndim, wrt.ndim)`.
+ """
+ if not jacs:
+ return torch.empty(0, 0)
+
+ n_out = jacs[0].shape[0]
+ return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
+
+
  def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
  """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
  Returns a sequence of tensors with the length as `wrt`.
- Each tensor will have the shape `(*input.shape, *wrt[i].shape)`.
+ Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.
 
  Args:
  input (Sequence[torch.Tensor]): input sequence of tensors.
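
The `flatten_jacobian` helper added above reshapes each per-parameter Jacobian block to `(n_out, -1)` and concatenates the blocks column-wise, so a list of blocks shaped `(*output.shape, *wrt[i].shape)` becomes one `(output numel, total wrt numel)` matrix. A small standalone illustration of that reshaping (plain torch, not the library function itself):

```python
import torch

out_numel = 3                           # output with 3 elements
jacs = [torch.randn(out_numel, 2, 4),   # block for a (2, 4) parameter
        torch.randn(out_numel, 5)]      # block for a (5,) parameter

flat = torch.cat([j.reshape(out_numel, -1) for j in jacs], dim=1)
print(flat.shape)                       # torch.Size([3, 13])
```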
@@ -75,10 +91,10 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
  return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
 
 
- def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
- """takes output of `hessian` and returns the 2D hessian matrix.
- Note - I only tested this for cases where input is a scalar."""
- return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
+ # def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
+ # """takes output of `hessian` and returns the 2D hessian matrix.
+ # Note - I only tested this for cases where input is a scalar."""
+ # return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
 
  def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
  """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
@@ -98,7 +114,7 @@ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
  """
  jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
  H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
- return torch.cat([j.view(-1) for j in jac]), hessian_list_to_mat(H_list)
+ return flatten_jacobian(jac), flatten_jacobian(H_list)
 
  def hessian(
  fn,
@@ -115,19 +131,18 @@ def hessian(
  `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.
 
  Example:
- .. code:: py
-
- model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
- X = torch.randn(10, 4)
- y = torch.randn(10, 2)
+ ```python
+ model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
+ X = torch.randn(10, 4)
+ y = torch.randn(10, 2)
 
- def fn():
- y_hat = model(X)
- loss = F.mse_loss(y_hat, y)
- return loss
-
- hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+ def fn():
+ y_hat = model(X)
+ loss = F.mse_loss(y_hat, y)
+ return loss
 
+ hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+ ```
 
  """
  params = list(params)
@@ -165,19 +180,18 @@ def hessian_mat(
  `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.
 
  Example:
- .. code:: py
-
- model = nn.Linear(4, 2) # 10 parameters in total
- X = torch.randn(10, 4)
- y = torch.randn(10, 2)
+ ```python
+ model = nn.Linear(4, 2) # 10 parameters in total
+ X = torch.randn(10, 4)
+ y = torch.randn(10, 2)
 
- def fn():
- y_hat = model(X)
- loss = F.mse_loss(y_hat, y)
- return loss
-
- hessian_mat(fn, model.parameters()) # 10x10 tensor
+ def fn():
+ y_hat = model(X)
+ loss = F.mse_loss(y_hat, y)
+ return loss
 
+ hessian_mat(fn, model.parameters()) # 10x10 tensor
+ ```
 
  """
  params = list(params)
@@ -206,21 +220,20 @@ def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) ->
  """Jacobian vector product.
 
  Example:
- .. code:: py
-
- model = nn.Linear(4, 2)
- X = torch.randn(10, 4)
- y = torch.randn(10, 2)
-
- tangent = [torch.randn_like(p) for p in model.parameters()]
+ ```python
+ model = nn.Linear(4, 2)
+ X = torch.randn(10, 4)
+ y = torch.randn(10, 2)
 
- def fn():
- y_hat = model(X)
- loss = F.mse_loss(y_hat, y)
- return loss
+ tangent = [torch.randn_like(p) for p in model.parameters()]
 
- jvp(fn, model.parameters(), tangent) # scalar
+ def fn():
+ y_hat = model(X)
+ loss = F.mse_loss(y_hat, y)
+ return loss
 
+ jvp(fn, model.parameters(), tangent) # scalar
+ ```
  """
  params = list(params)
  tangent = list(tangent)
@@ -253,21 +266,20 @@ def jvp_fd_central(
  """Jacobian vector product using central finite difference formula.
 
  Example:
- .. code:: py
-
- model = nn.Linear(4, 2)
- X = torch.randn(10, 4)
- y = torch.randn(10, 2)
-
- tangent = [torch.randn_like(p) for p in model.parameters()]
+ ```python
+ model = nn.Linear(4, 2)
+ X = torch.randn(10, 4)
+ y = torch.randn(10, 2)
 
- def fn():
- y_hat = model(X)
- loss = F.mse_loss(y_hat, y)
- return loss
+ tangent = [torch.randn_like(p) for p in model.parameters()]
 
- jvp_fd_central(fn, model.parameters(), tangent) # scalar
+ def fn():
+ y_hat = model(X)
+ loss = F.mse_loss(y_hat, y)
+ return loss
 
+ jvp_fd_central(fn, model.parameters(), tangent) # scalar
+ ```
  """
  params = list(params)
  tangent = list(tangent)
@@ -304,24 +316,24 @@ def jvp_fd_forward(
  Loss at initial point can be specified in the `v_0` argument.
 
  Example:
- .. code:: py
+ ```python
+ model = nn.Linear(4, 2)
+ X = torch.randn(10, 4)
+ y = torch.randn(10, 2)
 
- model = nn.Linear(4, 2)
- X = torch.randn(10, 4)
- y = torch.randn(10, 2)
+ tangent1 = [torch.randn_like(p) for p in model.parameters()]
+ tangent2 = [torch.randn_like(p) for p in model.parameters()]
 
- tangent1 = [torch.randn_like(p) for p in model.parameters()]
- tangent2 = [torch.randn_like(p) for p in model.parameters()]
-
- def fn():
- y_hat = model(X)
- loss = F.mse_loss(y_hat, y)
- return loss
+ def fn():
+ y_hat = model(X)
+ loss = F.mse_loss(y_hat, y)
+ return loss
 
- v_0 = fn() # pre-calculate loss at initial point
+ v_0 = fn() # pre-calculate loss at initial point
 
- jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
- jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar
+ jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
+ jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar
+ ```
 
  """
  params = list(params)
@@ -356,21 +368,21 @@ def hvp(
  """Hessian-vector product
 
  Example:
- .. code:: py
-
- model = nn.Linear(4, 2)
- X = torch.randn(10, 4)
- y = torch.randn(10, 2)
+ ```python
+ model = nn.Linear(4, 2)
+ X = torch.randn(10, 4)
+ y = torch.randn(10, 2)
 
- y_hat = model(X)
- loss = F.mse_loss(y_hat, y)
- loss.backward(create_graph=True)
+ y_hat = model(X)
+ loss = F.mse_loss(y_hat, y)
+ loss.backward(create_graph=True)
 
- grads = [p.grad for p in model.parameters()]
- vec = [torch.randn_like(p) for p in model.parameters()]
+ grads = [p.grad for p in model.parameters()]
+ vec = [torch.randn_like(p) for p in model.parameters()]
 
- # list of tensors, same layout as model.parameters()
- hvp(model.parameters(), grads, vec=vec)
+ # list of tensors, same layout as model.parameters()
+ hvp(model.parameters(), grads, vec=vec)
+ ```
  """
  params = list(params)
  g = list(grads)
@@ -393,23 +405,23 @@ def hvp_fd_central(
  Please note that this will clear :code:`grad` attributes in params.
 
  Example:
- .. code:: py
-
- model = nn.Linear(4, 2)
- X = torch.randn(10, 4)
- y = torch.randn(10, 2)
+ ```python
+ model = nn.Linear(4, 2)
+ X = torch.randn(10, 4)
+ y = torch.randn(10, 2)
 
- def closure():
- y_hat = model(X)
- loss = F.mse_loss(y_hat, y)
- model.zero_grad()
- loss.backward()
- return loss
+ def closure():
+ y_hat = model(X)
+ loss = F.mse_loss(y_hat, y)
+ model.zero_grad()
+ loss.backward()
+ return loss
 
- vec = [torch.randn_like(p) for p in model.parameters()]
+ vec = [torch.randn_like(p) for p in model.parameters()]
 
- # list of tensors, same layout as model.parameters()
- hvp_fd_central(closure, model.parameters(), vec=vec)
+ # list of tensors, same layout as model.parameters()
+ hvp_fd_central(closure, model.parameters(), vec=vec)
+ ```
  """
  params = list(params)
  vec = list(vec)
@@ -456,27 +468,27 @@ def hvp_fd_forward(
  Please note that this will clear :code:`grad` attributes in params.
 
  Example:
- .. code:: py
+ ```python
+ model = nn.Linear(4, 2)
+ X = torch.randn(10, 4)
+ y = torch.randn(10, 2)
 
- model = nn.Linear(4, 2)
- X = torch.randn(10, 4)
- y = torch.randn(10, 2)
-
- def closure():
- y_hat = model(X)
- loss = F.mse_loss(y_hat, y)
- model.zero_grad()
- loss.backward()
- return loss
+ def closure():
+ y_hat = model(X)
+ loss = F.mse_loss(y_hat, y)
+ model.zero_grad()
+ loss.backward()
+ return loss
 
- vec = [torch.randn_like(p) for p in model.parameters()]
+ vec = [torch.randn_like(p) for p in model.parameters()]
 
- # pre-compute gradient at initial point
- closure()
- g_0 = [p.grad for p in model.parameters()]
+ # pre-compute gradient at initial point
+ closure()
+ g_0 = [p.grad for p in model.parameters()]
 
- # list of tensors, same layout as model.parameters()
- hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
+ # list of tensors, same layout as model.parameters()
+ hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
+ ```
  """
 
  params = list(params)
@@ -485,7 +497,7 @@ def hvp_fd_forward(
 
  vec_norm = None
  if normalize:
- vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
+ vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in vec])) # pylint:disable=not-callable
  if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
  vec = torch._foreach_div(vec, vec_norm)
 

torchzero/utils/linalg/__init__.py
@@ -1,5 +1,12 @@
- from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix_power_eigh, x_inv
+ from . import linear_operator
+ from .matrix_funcs import (
+ eigvals_func,
+ inv_sqrt_2x2,
+ matrix_power_eigh,
+ singular_vals_func,
+ x_inv,
+ )
  from .orthogonalize import gram_schmidt
  from .qr import qr_householder
+ from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
  from .svd import randomized_svd
- from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve, steihaug_toint_cg