torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/optim/wrappers/optuna.py
@@ -6,7 +6,7 @@ import torch
 
 import optuna
 
-from ...utils import Optimizer
+from ...utils import Optimizer, totensor, tofloat
 
 def silence_optuna():
     optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -23,7 +23,7 @@ class OptunaSampler(Optimizer):
    Note - optuna is surprisingly scalable to large number of parameters (up to 10,000), despite literally requiring a for-loop because it only supports scalars. Default TPESampler is good for BBO. Maybe not for NNs...
 
    Args:
-        params (_type_): parameters
+        params: iterable of parameters to optimize or dicts defining parameter groups.
        lb (float): lower bounds.
        ub (float): upper bounds.
        sampler (optuna.samplers.BaseSampler | type[optuna.samplers.BaseSampler] | None, optional): sampler. Defaults to None.
@@ -65,6 +65,6 @@ class OptunaSampler(Optimizer):
        params.from_vec_(vec)
 
        loss = closure()
-       with torch.enable_grad(): self.study.tell(trial, loss)
+       with torch.enable_grad(): self.study.tell(trial, tofloat(torch.nan_to_num(totensor(loss), 1e32)))
 
        return loss
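The new `tell` call keeps optuna from choking on non-finite losses: the loss is converted to a tensor, NaN is replaced with a large finite penalty, and a plain float is reported to the study. A minimal sketch of that sanitization step, assuming `totensor`/`tofloat` behave like `torch.as_tensor` and `float()` (the `sanitize` helper below is illustrative, not part of the package):

```python
import torch

def sanitize(loss) -> float:
    # mirrors tofloat(torch.nan_to_num(totensor(loss), 1e32)):
    # convert to a tensor, swap NaN for 1e32, report a plain float
    return float(torch.nan_to_num(torch.as_tensor(loss), nan=1e32))

print(sanitize(torch.tensor(0.5)))  # 0.5
print(sanitize(float("nan")))       # ~1e32, a finite stand-in optuna can accept
```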
torchzero/optim/wrappers/scipy.py
@@ -4,12 +4,17 @@ from functools import partial
 from typing import Any, Literal
 
 import numpy as np
-import scipy.optimize
 import torch
 
+import scipy.optimize
+
 from ...utils import Optimizer, TensorList
-from ...utils.derivatives import jacobian_and_hessian_mat_wrt, jacobian_wrt
-from ...modules.second_order.newton import tikhonov_
+from ...utils.derivatives import (
+    flatten_jacobian,
+    jacobian_and_hessian_mat_wrt,
+    jacobian_wrt,
+)
+
 
 def _ensure_float(x) -> float:
     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
@@ -21,14 +26,6 @@ def _ensure_numpy(x):
     if isinstance(x, np.ndarray): return x
     return np.array(x)
 
-def matrix_clamp(H: torch.Tensor, reg: float):
-    try:
-        eigvals, eigvecs = torch.linalg.eigh(H) # pylint:disable=not-callable
-        eigvals.clamp_(min=reg)
-        return eigvecs @ torch.diag(eigvals) @ eigvecs.mH
-    except Exception:
-        return H
-
 Closure = Callable[[bool], Any]
 
 class ScipyMinimize(Optimizer):
@@ -76,8 +73,6 @@ class ScipyMinimize(Optimizer):
        options = None,
        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
        hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
-       tikhonov: float | None = 0,
-       min_eigval: float | None = None,
    ):
        defaults = dict(lb=lb, ub=ub)
        super().__init__(params, defaults)
@@ -85,12 +80,10 @@ class ScipyMinimize(Optimizer):
        self.constraints = constraints
        self.tol = tol
        self.callback = callback
-       self.min_eigval = min_eigval
        self.options = options
 
        self.jac = jac
        self.hess = hess
-       self.tikhonov: float | None = tikhonov
 
        self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
            'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
@@ -111,9 +104,7 @@ class ScipyMinimize(Optimizer):
        with torch.enable_grad():
            value = closure(False)
            _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
-       if self.tikhonov is not None: H = tikhonov_(H, self.tikhonov)
-       if self.min_eigval is not None: H = matrix_clamp(H, self.min_eigval)
-       return H.detach().cpu().numpy()
+       return H.numpy(force=True)
 
    def _objective(self, x: np.ndarray, params: TensorList, closure):
        # set params to x
@@ -122,7 +113,10 @@ class ScipyMinimize(Optimizer):
        # return value and maybe gradients
        if self.use_jac_autograd:
            with torch.enable_grad(): value = _ensure_float(closure())
-           return value, params.ensure_grad_().grad.to_vec().detach().cpu().numpy()
+           grad = params.ensure_grad_().grad.to_vec().numpy(force=True)
+           # slsqp requires float64
+           if self.method.lower() == 'slsqp': grad = grad.astype(np.float64)
+           return value, grad
        return _ensure_float(closure(False))
 
    @torch.no_grad
@@ -135,13 +129,15 @@ class ScipyMinimize(Optimizer):
            else: hess = None
        else: hess = self.hess
 
-       x0 = params.to_vec().detach().cpu().numpy()
+       x0 = params.to_vec().numpy(force=True)
 
        # make bounds
        lb, ub = self.group_vals('lb', 'ub', cls=list)
-       bounds = []
-       for p, l, u in zip(params, lb, ub):
-           bounds.extend([(l, u)] * p.numel())
+       bounds = None
+       if any(b is not None for b in lb) or any(b is not None for b in ub):
+           bounds = []
+           for p, l, u in zip(params, lb, ub):
+               bounds.extend([(l, u)] * p.numel())
 
        if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
            x0 = x0.astype(np.float64) # those methods error without this
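For reference, the bounds handling above now only builds a bounds list when at least one group actually sets `lb`/`ub`, and it expands each per-group bound into one `(low, high)` pair per scalar element, which is what `scipy.optimize.minimize` expects. A standalone sketch with plain tensors standing in for the parameter groups (names here are illustrative):

```python
import torch

params = [torch.zeros(2, 3), torch.zeros(4)]  # two parameter tensors
lb = [-1.0, None]                             # per-group lower bounds
ub = [1.0, None]                              # per-group upper bounds

bounds = None
if any(b is not None for b in lb) or any(b is not None for b in ub):
    bounds = []
    for p, l, u in zip(params, lb, ub):
        # one (low, high) pair per scalar element of the parameter
        bounds.extend([(l, u)] * p.numel())

print(len(bounds))  # 10 pairs: 6 for the (2, 3) tensor and 4 for the (4,) tensor
```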
@@ -165,7 +161,7 @@ class ScipyMinimize(Optimizer):
 
 
 class ScipyRootOptimization(Optimizer):
-    """Optimization via using scipy.root on gradients, mainly for experimenting!
+    """Optimization via using scipy.optimize.root on gradients, mainly for experimenting!
 
    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
@@ -246,6 +242,72 @@ class ScipyRootOptimization(Optimizer):
        return res.fun
 
 
+class ScipyLeastSquaresOptimization(Optimizer):
+    """Optimization via using scipy.optimize.least_squares on gradients, mainly for experimenting!
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        method (str | None, optional): _description_. Defaults to None.
+        tol (float | None, optional): _description_. Defaults to None.
+        callback (_type_, optional): _description_. Defaults to None.
+        options (_type_, optional): _description_. Defaults to None.
+        jac (T.Literal['2, optional): _description_. Defaults to 'autograd'.
+    """
+    def __init__(
+        self,
+        params,
+        method='trf',
+        jac='autograd',
+        bounds=(-np.inf, np.inf),
+        ftol=1e-8, xtol=1e-8, gtol=1e-8, x_scale=1.0, loss='linear',
+        f_scale=1.0, diff_step=None, tr_solver=None, tr_options=None,
+        jac_sparsity=None, max_nfev=None, verbose=0
+    ):
+        super().__init__(params, {})
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['__class__'], kwargs['jac']
+        self._kwargs = kwargs
+
+        self.jac = jac
+
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        # set params to x
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        # return the gradients
+        with torch.enable_grad(): self.value = closure()
+        jac = params.ensure_grad_().grad.to_vec()
+        return jac.numpy(force=True)
+
+    def _hess(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        with torch.enable_grad():
+            value = closure(False)
+            _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
+        return H.numpy(force=True)
+
+    @torch.no_grad
+    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+
+        if self.jac == 'autograd': jac = partial(self._hess, params = params, closure = closure)
+        else: jac = self.jac
+
+        res = scipy.optimize.least_squares(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            jac=jac, # type:ignore
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.fun
+
+
+
 
 class ScipyDE(Optimizer):
    """Use scipy.minimize.differential_evolution as pytorch optimizer. Note that this performs full minimization on each step,
@@ -508,4 +570,3 @@ class ScipyBrute(Optimizer):
            **self._kwargs
        )
        params.from_vec_(torch.from_numpy(x0).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-       return None
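A hedged usage sketch for the new `ScipyLeastSquaresOptimization` wrapper, assuming the closure contract visible in `_objective`/`_hess` above (the closure returns the loss and populates `.grad` when called with no argument, and only evaluates the loss when called with `False`):

```python
import torch
import torch.nn.functional as F
from torchzero.optim.wrappers.scipy import ScipyLeastSquaresOptimization

model = torch.nn.Linear(4, 2)
X, y = torch.randn(10, 4), torch.randn(10, 2)

opt = ScipyLeastSquaresOptimization(model.parameters(), method="trf")

def closure(backward=True):
    # torchzero-style closure: return the loss, populate .grad when backward=True
    loss = F.mse_loss(model(X), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # runs scipy.optimize.least_squares with the gradient as the residual vector
```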
torchzero/utils/__init__.py
@@ -1,5 +1,11 @@
 from . import tensorlist as tl
-from .compile import _optional_compiler, benchmark_compile_cpu, benchmark_compile_cuda, set_compilation, enable_compilation
+from .compile import (
+    _optional_compiler,
+    benchmark_compile_cpu,
+    benchmark_compile_cuda,
+    enable_compilation,
+    set_compilation,
+)
 from .numberlist import NumberList
 from .optimizer import (
     Init,
@@ -18,6 +24,36 @@ from .params import (
     _copy_param_groups,
     _make_param_groups,
 )
-from .python_tools import flatten, generic_eq, reduce_dim, unpack_dicts
-from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like
-from .torch_tools import tofloat, tolist, tonumpy, totensor, vec_to_tensors, vec_to_tensors_, set_storage_
+from .python_tools import (
+    flatten,
+    generic_eq,
+    generic_ne,
+    reduce_dim,
+    safe_dict_update_,
+    unpack_dicts,
+)
+from .tensorlist import (
+    Distributions,
+    Metrics,
+    TensorList,
+    as_tensorlist,
+    generic_clamp,
+    generic_finfo,
+    generic_finfo_eps,
+    generic_finfo_tiny,
+    generic_max,
+    generic_numel,
+    generic_randn_like,
+    generic_sum,
+    generic_vector_norm,
+    generic_zeros_like,
+)
+from .torch_tools import (
+    set_storage_,
+    tofloat,
+    tolist,
+    tonumpy,
+    totensor,
+    vec_to_tensors,
+    vec_to_tensors_,
+)
torchzero/utils/compile.py
@@ -38,7 +38,7 @@ class _MaybeCompiledFunc:
 _optional_compiler = _OptionalCompiler()
 """this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""
 
-def set_compilation(enable: bool):
+def set_compilation(enable: bool=True):
     """`enable` is False by default. When True, certain functions will be compiled, which may not work on some systems like Windows, but it usually improves performance."""
     _optional_compiler.enable = enable
 
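With the new default argument, opting in to the compiled code paths no longer needs an explicit `True`; `set_compilation` is re-exported from `torchzero.utils` (see the `__init__.py` hunk above). A short usage note:

```python
from torchzero.utils import set_compilation

set_compilation()       # now equivalent to set_compilation(True)
set_compilation(False)  # restore the default, uncompiled behaviour
```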
torchzero/utils/derivatives.py
@@ -2,7 +2,6 @@ from collections.abc import Iterable, Sequence
 
 import torch
 import torch.autograd.forward_ad as fwAD
-from typing import Literal
 
 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
 
@@ -35,10 +34,27 @@ def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor
         is_grads_batched=True,
     )
 
+def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
+    """Converts the output of jacobian_wrt (a list of tensors) into a single 2D matrix.
+
+    Args:
+        jacs (Sequence[torch.Tensor]):
+            output from jacobian_wrt where ach tensor has the shape `(*output.shape, *wrt[i].shape)`.
+
+    Returns:
+        torch.Tensor: has the shape `(output.ndim, wrt.ndim)`.
+    """
+    if not jacs:
+        return torch.empty(0, 0)
+
+    n_out = jacs[0].shape[0]
+    return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
+
+
 def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
     """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
     Returns a sequence of tensors with the length as `wrt`.
-    Each tensor will have the shape `(*input.shape, *wrt[i].shape)`.
+    Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.
 
     Args:
         input (Sequence[torch.Tensor]): input sequence of tensors.
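A quick shape check of the new helper; the function body is copied verbatim from the hunk above so the snippet runs on its own:

```python
from collections.abc import Sequence
import torch

def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
    # identical to the helper added to torchzero/utils/derivatives.py
    if not jacs:
        return torch.empty(0, 0)
    n_out = jacs[0].shape[0]
    return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)

# per-parameter jacobians of a 5-element output w.r.t. a (2, 3) weight and a (4,) bias
jacs = [torch.randn(5, 2, 3), torch.randn(5, 4)]
print(flatten_jacobian(jacs).shape)  # torch.Size([5, 10])
```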
@@ -75,10 +91,10 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
     return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
 
 
-def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
-    """takes output of `hessian` and returns the 2D hessian matrix.
-    Note - I only tested this for cases where input is a scalar."""
-    return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
+# def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
+#     """takes output of `hessian` and returns the 2D hessian matrix.
+#     Note - I only tested this for cases where input is a scalar."""
+#     return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
 
 def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
@@ -98,7 +114,7 @@ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
     """
     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
     H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
-    return torch.cat([j.view(-1) for j in jac]), hessian_list_to_mat(H_list)
+    return flatten_jacobian(jac), flatten_jacobian(H_list)
 
 def hessian(
     fn,
@@ -115,19 +131,18 @@ def hessian(
     `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.
 
     Example:
-    .. code:: py
-
-        model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
-        X = torch.randn(10, 4)
-        y = torch.randn(10, 2)
+    ```python
+    model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
+    X = torch.randn(10, 4)
+    y = torch.randn(10, 2)
 
-        def fn():
-            y_hat = model(X)
-            loss = F.mse_loss(y_hat, y)
-            return loss
-
-        hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+    def fn():
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        return loss
 
+    hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+    ```
 
     """
     params = list(params)
@@ -158,26 +173,25 @@ def hessian_mat(
     method="func",
     vectorize=False,
     outer_jacobian_strategy="reverse-mode",
-    ):
+    ) -> torch.Tensor:
     """
     returns hessian matrix for parameters (as if they were flattened and concatenated into a vector).
 
     `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.
 
     Example:
-    .. code:: py
-
-        model = nn.Linear(4, 2) # 10 parameters in total
-        X = torch.randn(10, 4)
-        y = torch.randn(10, 2)
+    ```python
+    model = nn.Linear(4, 2) # 10 parameters in total
+    X = torch.randn(10, 4)
+    y = torch.randn(10, 2)
 
-        def fn():
-            y_hat = model(X)
-            loss = F.mse_loss(y_hat, y)
-            return loss
-
-        hessian_mat(fn, model.parameters()) # 10x10 tensor
+    def fn():
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        return loss
 
+    hessian_mat(fn, model.parameters()) # 10x10 tensor
+    ```
 
     """
     params = list(params)
@@ -190,7 +204,7 @@ def hessian_mat(
         return loss
 
     if method == 'func':
-        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph))
+        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph)) # pyright:ignore[reportReturnType]
 
     if method == 'autograd.functional':
         return torch.autograd.functional.hessian(
@@ -199,28 +213,27 @@
             create_graph=create_graph,
             vectorize=vectorize,
             outer_jacobian_strategy=outer_jacobian_strategy,
-        )
+        ) # pyright:ignore[reportReturnType]
     raise ValueError(method)
 
 def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
     """Jacobian vector product.
 
     Example:
-    .. code:: py
-
-        model = nn.Linear(4, 2)
-        X = torch.randn(10, 4)
-        y = torch.randn(10, 2)
+    ```python
+    model = nn.Linear(4, 2)
+    X = torch.randn(10, 4)
+    y = torch.randn(10, 2)
 
-        tangent = [torch.randn_like(p) for p in model.parameters()]
+    tangent = [torch.randn_like(p) for p in model.parameters()]
 
-        def fn():
-            y_hat = model(X)
-            loss = F.mse_loss(y_hat, y)
-            return loss
-
-        jvp(fn, model.parameters(), tangent) # scalar
+    def fn():
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        return loss
 
+    jvp(fn, model.parameters(), tangent) # scalar
+    ```
     """
     params = list(params)
     tangent = list(tangent)
@@ -253,21 +266,20 @@ def jvp_fd_central(
     """Jacobian vector product using central finite difference formula.
 
     Example:
-    .. code:: py
-
-        model = nn.Linear(4, 2)
-        X = torch.randn(10, 4)
-        y = torch.randn(10, 2)
+    ```python
+    model = nn.Linear(4, 2)
+    X = torch.randn(10, 4)
+    y = torch.randn(10, 2)
 
-        tangent = [torch.randn_like(p) for p in model.parameters()]
+    tangent = [torch.randn_like(p) for p in model.parameters()]
 
-        def fn():
-            y_hat = model(X)
-            loss = F.mse_loss(y_hat, y)
-            return loss
-
-        jvp_fd_central(fn, model.parameters(), tangent) # scalar
+    def fn():
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        return loss
 
+    jvp_fd_central(fn, model.parameters(), tangent) # scalar
+    ```
     """
     params = list(params)
     tangent = list(tangent)
@@ -304,24 +316,24 @@ def jvp_fd_forward(
     Loss at initial point can be specified in the `v_0` argument.
 
     Example:
-    .. code:: py
-
-        model = nn.Linear(4, 2)
-        X = torch.randn(10, 4)
-        y = torch.randn(10, 2)
+    ```python
+    model = nn.Linear(4, 2)
+    X = torch.randn(10, 4)
+    y = torch.randn(10, 2)
 
-        tangent1 = [torch.randn_like(p) for p in model.parameters()]
-        tangent2 = [torch.randn_like(p) for p in model.parameters()]
+    tangent1 = [torch.randn_like(p) for p in model.parameters()]
+    tangent2 = [torch.randn_like(p) for p in model.parameters()]
 
-        def fn():
-            y_hat = model(X)
-            loss = F.mse_loss(y_hat, y)
-            return loss
+    def fn():
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        return loss
 
-        v_0 = fn() # pre-calculate loss at initial point
+    v_0 = fn() # pre-calculate loss at initial point
 
-        jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
-        jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar
+    jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
+    jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar
+    ```
 
     """
     params = list(params)
@@ -356,21 +368,21 @@ def hvp(
     """Hessian-vector product
 
     Example:
-    .. code:: py
+    ```python
+    model = nn.Linear(4, 2)
+    X = torch.randn(10, 4)
+    y = torch.randn(10, 2)
 
-        model = nn.Linear(4, 2)
-        X = torch.randn(10, 4)
-        y = torch.randn(10, 2)
+    y_hat = model(X)
+    loss = F.mse_loss(y_hat, y)
+    loss.backward(create_graph=True)
 
-        y_hat = model(X)
-        loss = F.mse_loss(y_hat, y)
-        loss.backward(create_graph=True)
-
-        grads = [p.grad for p in model.parameters()]
-        vec = [torch.randn_like(p) for p in model.parameters()]
+    grads = [p.grad for p in model.parameters()]
+    vec = [torch.randn_like(p) for p in model.parameters()]
 
-        # list of tensors, same layout as model.parameters()
-        hvp(model.parameters(), grads, vec=vec)
+    # list of tensors, same layout as model.parameters()
+    hvp(model.parameters(), grads, vec=vec)
+    ```
     """
     params = list(params)
     g = list(grads)
@@ -393,23 +405,23 @@ def hvp_fd_central(
     Please note that this will clear :code:`grad` attributes in params.
 
     Example:
-    .. code:: py
-
-        model = nn.Linear(4, 2)
-        X = torch.randn(10, 4)
-        y = torch.randn(10, 2)
+    ```python
+    model = nn.Linear(4, 2)
+    X = torch.randn(10, 4)
+    y = torch.randn(10, 2)
 
-        def closure():
-            y_hat = model(X)
-            loss = F.mse_loss(y_hat, y)
-            model.zero_grad()
-            loss.backward()
-            return loss
+    def closure():
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        model.zero_grad()
+        loss.backward()
+        return loss
 
-        vec = [torch.randn_like(p) for p in model.parameters()]
+    vec = [torch.randn_like(p) for p in model.parameters()]
 
-        # list of tensors, same layout as model.parameters()
-        hvp_fd_central(closure, model.parameters(), vec=vec)
+    # list of tensors, same layout as model.parameters()
+    hvp_fd_central(closure, model.parameters(), vec=vec)
+    ```
     """
     params = list(params)
     vec = list(vec)
@@ -456,27 +468,27 @@ def hvp_fd_forward(
     Please note that this will clear :code:`grad` attributes in params.
 
     Example:
-    .. code:: py
+    ```python
+    model = nn.Linear(4, 2)
+    X = torch.randn(10, 4)
+    y = torch.randn(10, 2)
 
-        model = nn.Linear(4, 2)
-        X = torch.randn(10, 4)
-        y = torch.randn(10, 2)
-
-        def closure():
-            y_hat = model(X)
-            loss = F.mse_loss(y_hat, y)
-            model.zero_grad()
-            loss.backward()
-            return loss
+    def closure():
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        model.zero_grad()
+        loss.backward()
+        return loss
 
-        vec = [torch.randn_like(p) for p in model.parameters()]
+    vec = [torch.randn_like(p) for p in model.parameters()]
 
-        # pre-compute gradient at initial point
-        closure()
-        g_0 = [p.grad for p in model.parameters()]
+    # pre-compute gradient at initial point
+    closure()
+    g_0 = [p.grad for p in model.parameters()]
 
-        # list of tensors, same layout as model.parameters()
-        hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
+    # list of tensors, same layout as model.parameters()
+    hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
+    ```
     """
 
     params = list(params)
@@ -485,7 +497,7 @@ def hvp_fd_forward(
 
     vec_norm = None
     if normalize:
-        vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
+        vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in vec])) # pylint:disable=not-callable
         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
         vec = torch._foreach_div(vec, vec_norm)
 
torchzero/utils/linalg/__init__.py
@@ -1,5 +1,12 @@
-from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix_power_eigh, x_inv
+from . import linear_operator
+from .matrix_funcs import (
+    eigvals_func,
+    inv_sqrt_2x2,
+    matrix_power_eigh,
+    singular_vals_func,
+    x_inv,
+)
 from .orthogonalize import gram_schmidt
 from .qr import qr_householder
+from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
 from .svd import randomized_svd
-from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve