torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/utils/__init__.py

@@ -1,5 +1,11 @@
 from . import tensorlist as tl
-from .compile import _optional_compiler, benchmark_compile_cpu, benchmark_compile_cuda, set_compilation, enable_compilation
+from .compile import (
+    _optional_compiler,
+    benchmark_compile_cpu,
+    benchmark_compile_cuda,
+    enable_compilation,
+    set_compilation,
+)
 from .numberlist import NumberList
 from .optimizer import (
     Init,
@@ -18,6 +24,36 @@ from .params import (
     _copy_param_groups,
     _make_param_groups,
 )
-from .python_tools import flatten, generic_eq, generic_ne, reduce_dim, unpack_dicts
-from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like, generic_finfo_eps
-from .torch_tools import tofloat, tolist, tonumpy, totensor, vec_to_tensors, vec_to_tensors_, set_storage_
+from .python_tools import (
+    flatten,
+    generic_eq,
+    generic_ne,
+    reduce_dim,
+    safe_dict_update_,
+    unpack_dicts,
+)
+from .tensorlist import (
+    Distributions,
+    Metrics,
+    TensorList,
+    as_tensorlist,
+    generic_clamp,
+    generic_finfo,
+    generic_finfo_eps,
+    generic_finfo_tiny,
+    generic_max,
+    generic_numel,
+    generic_randn_like,
+    generic_sum,
+    generic_vector_norm,
+    generic_zeros_like,
+)
+from .torch_tools import (
+    set_storage_,
+    tofloat,
+    tolist,
+    tonumpy,
+    totensor,
+    vec_to_tensors,
+    vec_to_tensors_,
+)
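
The utils namespace gains several new exports in this release. A minimal sketch of the new surface, using only names that appear verbatim in the import list above; `torchzero.utils` as the import path is an assumption based on this file being `torchzero/utils/__init__.py`:

```python
# names copied from the new import list above; availability on 0.3.14 is assumed
from torchzero.utils import (
    Metrics,             # newly exported
    TensorList,
    generic_finfo,       # newly exported
    generic_finfo_tiny,  # newly exported
    generic_max,         # newly exported
    generic_sum,         # newly exported
    safe_dict_update_,   # newly exported
)
```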
torchzero/utils/compile.py

@@ -38,7 +38,7 @@ class _MaybeCompiledFunc:
 _optional_compiler = _OptionalCompiler()
 """this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""

-def set_compilation(enable: bool):
+def set_compilation(enable: bool=True):
     """`enable` is False by default. When True, certain functions will be compiled, which may not work on some systems like Windows, but it usually improves performance."""
     _optional_compiler.enable = enable

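The only functional change here is the new default: a bare `set_compilation()` call now enables compilation (the unchanged docstring's "`enable` is False by default" presumably refers to the initial state of the global flag). A minimal usage sketch; `set_compilation` is re-exported from `torchzero.utils` per the first hunk above:

```python
from torchzero.utils import set_compilation

set_compilation()       # with the new default this is equivalent to set_compilation(True)
set_compilation(False)  # restore the library default (compilation off)
```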
torchzero/utils/derivatives.py

@@ -2,7 +2,6 @@ from collections.abc import Iterable, Sequence

 import torch
 import torch.autograd.forward_ad as fwAD
-from typing import Literal

 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors

@@ -35,10 +34,27 @@ def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor
         is_grads_batched=True,
     )

+def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
+    """Converts the output of jacobian_wrt (a list of tensors) into a single 2D matrix.
+
+    Args:
+        jacs (Sequence[torch.Tensor]):
+            output from jacobian_wrt, where each tensor has the shape `(*output.shape, *wrt[i].shape)`.
+
+    Returns:
+        torch.Tensor: has the shape `(output.numel(), wrt.numel())`.
+    """
+    if not jacs:
+        return torch.empty(0, 0)
+
+    n_out = jacs[0].shape[0]
+    return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
+
+
 def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
     """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
     Returns a sequence of tensors with the same length as `wrt`.
-    Each tensor will have the shape `(*input.shape, *wrt[i].shape)`.
+    Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.

     Args:
         input (Sequence[torch.Tensor]): input sequence of tensors.
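
A minimal sketch of what the new `flatten_jacobian` does, assuming it is importable from `torchzero.utils.derivatives`; the shapes follow directly from the implementation above:

```python
import torch
from torchzero.utils.derivatives import flatten_jacobian  # assumed import path

# per the docstring, each entry has shape (*output.shape, *wrt[i].shape);
# here the output has 5 elements and the two wrt tensors have 6 and 4 elements
jacs = [torch.randn(5, 2, 3), torch.randn(5, 4)]

J = flatten_jacobian(jacs)
print(J.shape)  # torch.Size([5, 10]): each jacobian reshaped to (5, -1), then concatenated
```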
@@ -75,10 +91,10 @@ jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
     return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)


-def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
-    """takes output of `hessian` and returns the 2D hessian matrix.
-    Note - I only tested this for cases where input is a scalar."""
-    return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
+# def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
+#     """takes output of `hessian` and returns the 2D hessian matrix.
+#     Note - I only tested this for cases where input is a scalar."""
+#     return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)

 def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
@@ -98,7 +114,7 @@ jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
     """
     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
     H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
-    return torch.cat([j.view(-1) for j in jac]), hessian_list_to_mat(H_list)
+    return flatten_jacobian(jac), flatten_jacobian(H_list)

 def hessian(
     fn,
@@ -115,19 +131,18 @@ def hessian(
     `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.

     Example:
-        .. code:: py
-
-            model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
-            X = torch.randn(10, 4)
-            y = torch.randn(10, 2)
+        ```python
+        model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)

-            def fn():
-                y_hat = model(X)
-                loss = F.mse_loss(y_hat, y)
-                return loss
-
-            hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss

+        hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+        ```


     """
@@ -165,19 +180,18 @@ def hessian_mat(
     `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.

     Example:
-        .. code:: py
-
-            model = nn.Linear(4, 2) # 10 parameters in total
-            X = torch.randn(10, 4)
-            y = torch.randn(10, 2)
+        ```python
+        model = nn.Linear(4, 2) # 10 parameters in total
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)

-            def fn():
-                y_hat = model(X)
-                loss = F.mse_loss(y_hat, y)
-                return loss
-
-            hessian_mat(fn, model.parameters()) # 10x10 tensor
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss

+        hessian_mat(fn, model.parameters()) # 10x10 tensor
+        ```


     """
@@ -206,21 +220,20 @@ def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) ->
     """Jacobian vector product.

     Example:
-        .. code:: py
-
-            model = nn.Linear(4, 2)
-            X = torch.randn(10, 4)
-            y = torch.randn(10, 2)
-
-            tangent = [torch.randn_like(p) for p in model.parameters()]
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)

-            def fn():
-                y_hat = model(X)
-                loss = F.mse_loss(y_hat, y)
-                return loss
+        tangent = [torch.randn_like(p) for p in model.parameters()]

-            jvp(fn, model.parameters(), tangent) # scalar
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss

+        jvp(fn, model.parameters(), tangent) # scalar
+        ```
     """
     params = list(params)
     tangent = list(tangent)
@@ -253,21 +266,20 @@ def jvp_fd_central(
     """Jacobian vector product using central finite difference formula.

     Example:
-        .. code:: py
-
-            model = nn.Linear(4, 2)
-            X = torch.randn(10, 4)
-            y = torch.randn(10, 2)
-
-            tangent = [torch.randn_like(p) for p in model.parameters()]
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)

-            def fn():
-                y_hat = model(X)
-                loss = F.mse_loss(y_hat, y)
-                return loss
+        tangent = [torch.randn_like(p) for p in model.parameters()]

-            jvp_fd_central(fn, model.parameters(), tangent) # scalar
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss

+        jvp_fd_central(fn, model.parameters(), tangent) # scalar
+        ```
     """
     params = list(params)
     tangent = list(tangent)
@@ -304,24 +316,24 @@ def jvp_fd_forward(
     Loss at initial point can be specified in the `v_0` argument.

     Example:
-        .. code:: py
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)

-            model = nn.Linear(4, 2)
-            X = torch.randn(10, 4)
-            y = torch.randn(10, 2)
+        tangent1 = [torch.randn_like(p) for p in model.parameters()]
+        tangent2 = [torch.randn_like(p) for p in model.parameters()]

-            tangent1 = [torch.randn_like(p) for p in model.parameters()]
-            tangent2 = [torch.randn_like(p) for p in model.parameters()]
-
-            def fn():
-                y_hat = model(X)
-                loss = F.mse_loss(y_hat, y)
-                return loss
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss

-            v_0 = fn() # pre-calculate loss at initial point
+        v_0 = fn() # pre-calculate loss at initial point

-            jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
-            jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar
+        jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
+        jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar
+        ```


     """
@@ -356,21 +368,21 @@ def hvp(
     """Hessian-vector product

     Example:
-        .. code:: py
-
-            model = nn.Linear(4, 2)
-            X = torch.randn(10, 4)
-            y = torch.randn(10, 2)
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)

-            y_hat = model(X)
-            loss = F.mse_loss(y_hat, y)
-            loss.backward(create_graph=True)
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        loss.backward(create_graph=True)

-            grads = [p.grad for p in model.parameters()]
-            vec = [torch.randn_like(p) for p in model.parameters()]
+        grads = [p.grad for p in model.parameters()]
+        vec = [torch.randn_like(p) for p in model.parameters()]

-            # list of tensors, same layout as model.parameters()
-            hvp(model.parameters(), grads, vec=vec)
+        # list of tensors, same layout as model.parameters()
+        hvp(model.parameters(), grads, vec=vec)
+        ```
     """
     params = list(params)
     g = list(grads)
@@ -393,23 +405,23 @@ def hvp_fd_central(
     Please note that this will clear :code:`grad` attributes in params.

     Example:
-        .. code:: py
-
-            model = nn.Linear(4, 2)
-            X = torch.randn(10, 4)
-            y = torch.randn(10, 2)
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)

-            def closure():
-                y_hat = model(X)
-                loss = F.mse_loss(y_hat, y)
-                model.zero_grad()
-                loss.backward()
-                return loss
+        def closure():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            model.zero_grad()
+            loss.backward()
+            return loss

-            vec = [torch.randn_like(p) for p in model.parameters()]
+        vec = [torch.randn_like(p) for p in model.parameters()]

-            # list of tensors, same layout as model.parameters()
-            hvp_fd_central(closure, model.parameters(), vec=vec)
+        # list of tensors, same layout as model.parameters()
+        hvp_fd_central(closure, model.parameters(), vec=vec)
+        ```
     """
     params = list(params)
     vec = list(vec)
@@ -456,27 +468,27 @@ def hvp_fd_forward(
     Please note that this will clear :code:`grad` attributes in params.

     Example:
-        .. code:: py
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)

-            model = nn.Linear(4, 2)
-            X = torch.randn(10, 4)
-            y = torch.randn(10, 2)
-
-            def closure():
-                y_hat = model(X)
-                loss = F.mse_loss(y_hat, y)
-                model.zero_grad()
-                loss.backward()
-                return loss
+        def closure():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            model.zero_grad()
+            loss.backward()
+            return loss

-            vec = [torch.randn_like(p) for p in model.parameters()]
+        vec = [torch.randn_like(p) for p in model.parameters()]

-            # pre-compute gradient at initial point
-            closure()
-            g_0 = [p.grad for p in model.parameters()]
+        # pre-compute gradient at initial point
+        closure()
+        g_0 = [p.grad for p in model.parameters()]

-            # list of tensors, same layout as model.parameters()
-            hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
+        # list of tensors, same layout as model.parameters()
+        hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
+        ```
     """

     params = list(params)
@@ -485,7 +497,7 @@ def hvp_fd_forward(

     vec_norm = None
     if normalize:
-        vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
+        vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in vec])) # pylint:disable=not-callable
         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
         vec = torch._foreach_div(vec, vec_norm)

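The `view(-1)` → `ravel()` change matters when a tensor in `vec` is non-contiguous: `Tensor.view` requires a stride-compatible layout and raises otherwise, while `Tensor.ravel` silently falls back to a copy. A standalone illustration in plain PyTorch:

```python
import torch

t = torch.randn(3, 4).t()   # the transpose is non-contiguous
print(t.ravel().shape)      # torch.Size([12]) -- ravel copies when it has to

try:
    t.view(-1)              # view cannot reinterpret non-contiguous storage
except RuntimeError as e:
    print("view(-1) failed:", e)
```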
torchzero/utils/linalg/__init__.py

@@ -1,5 +1,12 @@
-from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix_power_eigh, x_inv
+from . import linear_operator
+from .matrix_funcs import (
+    eigvals_func,
+    inv_sqrt_2x2,
+    matrix_power_eigh,
+    singular_vals_func,
+    x_inv,
+)
 from .orthogonalize import gram_schmidt
 from .qr import qr_householder
+from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
 from .svd import randomized_svd
-from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve, steihaug_toint_cg
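
Note that `steihaug_toint_cg` is no longer re-exported here; per the file list, the trust-region solvers now live under `torchzero/modules/trust_region/`, so that is presumably its new home. A sketch of the reorganized exports, using only names from the import list above:

```python
# names copied from the new import list above; availability on 0.3.14 is assumed
from torchzero.utils.linalg import (
    cg,
    linear_operator,
    nystrom_approximation,
    nystrom_sketch_and_solve,
)
```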