torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/utils/derivatives.py

@@ -4,14 +4,15 @@ import torch
  import torch.autograd.forward_ad as fwAD
  
  from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
+ from .tensorlist import TensorList
  
- def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-     flat_input = torch.cat([i.reshape(-1) for i in output])
-     grad_ouputs = torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype)
+ def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+     flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
+     grad_ouputs = torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype)
      jac = []
-     for i in range(flat_input.numel()):
+     for i in range(flat_outputs.numel()):
          jac.append(torch.autograd.grad(
-             flat_input,
+             flat_outputs,
              wrt,
              grad_ouputs[i],
              retain_graph=True,
@@ -22,12 +23,12 @@ def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], creat
      return [torch.stack(z) for z in zip(*jac)]
  
  
- def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-     flat_input = torch.cat([i.reshape(-1) for i in output])
+ def _jacobian_batched(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+     flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
      return torch.autograd.grad(
-         flat_input,
+         flat_outputs,
          wrt,
-         torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype),
+         torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype),
          retain_graph=True,
          create_graph=create_graph,
          allow_unused=True,
@@ -51,13 +52,13 @@ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
      return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
  
  
- def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
+ def jacobian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
      """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
      Returns a sequence of tensors with the length as `wrt`.
      Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.
  
      Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
+         outputs (Sequence[torch.Tensor]): input sequence of tensors.
          wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
          create_graph (bool, optional):
              pytorch option, if True, graph of the derivative will be constructed,
@@ -68,16 +69,16 @@ def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], cr
      Returns:
          sequence of tensors with the length as `wrt`.
      """
-     if batched: return _jacobian_batched(output, wrt, create_graph)
-     return _jacobian(output, wrt, create_graph)
+     if batched: return _jacobian_batched(outputs, wrt, create_graph)
+     return _jacobian(outputs, wrt, create_graph)
  
- def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+ def jacobian_and_hessian_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
      """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
      Calculating hessian requires calculating the jacobian. So this function is more efficient than
      calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
  
      Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
+         outputs (Sequence[torch.Tensor]): input sequence of tensors.
          wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
          create_graph (bool, optional):
              pytorch option, if True, graph of the derivative will be constructed,
@@ -87,7 +88,7 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
      Returns:
          tuple with jacobians sequence and hessians sequence.
      """
-     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+     jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
      return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
  
  
@@ -96,13 +97,13 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
  # Note - I only tested this for cases where input is a scalar."""
  # return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
  
- def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+ def jacobian_and_hessian_mat_wrt(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
      """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
      Calculating hessian requires calculating the jacobian. So this function is more efficient than
      calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
  
      Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
+         outputs (Sequence[torch.Tensor]): input sequence of tensors.
          wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
          create_graph (bool, optional):
              pytorch option, if True, graph of the derivative will be constructed,
@@ -112,7 +113,7 @@ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
      Returns:
          tuple with jacobians sequence and hessians sequence.
      """
-     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+     jac = jacobian_wrt(outputs, wrt, create_graph=True, batched = batched)
      H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
      return flatten_jacobian(jac), flatten_jacobian(H_list)
  
@@ -261,7 +262,7 @@ def jvp_fd_central(
      params: Iterable[torch.Tensor],
      tangent: Iterable[torch.Tensor],
      h=1e-3,
-     normalize=False,
+     normalize=True,
  ) -> tuple[torch.Tensor | None, torch.Tensor]:
      """Jacobian vector product using central finite difference formula.
  
@@ -310,7 +311,7 @@ def jvp_fd_forward(
      tangent: Iterable[torch.Tensor],
      h=1e-3,
      v_0=None,
-     normalize=False,
+     normalize=True,
  ) -> tuple[torch.Tensor | None, torch.Tensor]:
      """Jacobian vector product using forward finite difference formula.
      Loss at initial point can be specified in the `v_0` argument.
@@ -357,52 +358,18 @@ def jvp_fd_forward(
      if normalize: res = res * tangent_norm
      return v_0, res
  
- def hvp(
-     params: Iterable[torch.Tensor],
-     grads: Iterable[torch.Tensor],
-     vec: Iterable[torch.Tensor],
-     retain_graph=None,
-     create_graph=False,
-     allow_unused=None,
- ):
-     """Hessian-vector product
-
-     Example:
-     ```python
-     model = nn.Linear(4, 2)
-     X = torch.randn(10, 4)
-     y = torch.randn(10, 2)
-
-     y_hat = model(X)
-     loss = F.mse_loss(y_hat, y)
-     loss.backward(create_graph=True)
-
-     grads = [p.grad for p in model.parameters()]
-     vec = [torch.randn_like(p) for p in model.parameters()]
-
-     # list of tensors, same layout as model.parameters()
-     hvp(model.parameters(), grads, vec=vec)
-     ```
-     """
-     params = list(params)
-     g = list(grads)
-     vec = list(vec)
-
-     with torch.enable_grad():
-         return torch.autograd.grad(g, params, vec, create_graph=create_graph, retain_graph=retain_graph, allow_unused=allow_unused)
-
  
  @torch.no_grad
  def hvp_fd_central(
      closure,
      params: Iterable[torch.Tensor],
-     vec: Iterable[torch.Tensor],
+     x: Iterable[torch.Tensor],
      h=1e-3,
-     normalize=False,
+     normalize=True,
  ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
-     """Hessian-vector product using central finite difference formula.
+     """Returns ``(loss_approx, Hx)``.
  
-     Please note that this will clear :code:`grad` attributes in params.
+     Please note that this will clear ``grad`` attributes in params.
  
      Example:
      ```python
@@ -424,48 +391,48 @@ def hvp_fd_central(
      ```
      """
      params = list(params)
-     vec = list(vec)
+     x = list(x)
  
      vec_norm = None
      if normalize:
-         vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
+         vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in x])) # pylint:disable=not-callable
          if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
-         vec = torch._foreach_div(vec, vec_norm)
+         x = torch._foreach_div(x, vec_norm)
  
-     vec_h = torch._foreach_mul(vec, h)
-     torch._foreach_add_(params, vec_h)
+     xh = torch._foreach_mul(x, h)
+     torch._foreach_add_(params, xh)
      with torch.enable_grad(): loss = closure()
      g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
  
-     torch._foreach_sub_(params, vec_h)
-     torch._foreach_sub_(params, vec_h)
+     torch._foreach_sub_(params, xh)
+     torch._foreach_sub_(params, xh)
      with torch.enable_grad(): loss = closure()
      g_minus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
  
-     torch._foreach_add_(params, vec_h)
+     torch._foreach_add_(params, xh)
      for p in params: p.grad = None
  
-     hvp_ = g_plus
-     torch._foreach_sub_(hvp_, g_minus)
-     torch._foreach_div_(hvp_, 2*h)
+     hx = g_plus
+     torch._foreach_sub_(hx, g_minus)
+     torch._foreach_div_(hx, 2*h)
  
-     if normalize: torch._foreach_mul_(hvp_, vec_norm)
-     return loss, hvp_
+     if normalize: torch._foreach_mul_(hx, vec_norm)
+     return loss, hx
  
  @torch.no_grad
  def hvp_fd_forward(
      closure,
      params: Iterable[torch.Tensor],
-     vec: Iterable[torch.Tensor],
+     x: Iterable[torch.Tensor],
      h=1e-3,
      g_0=None,
-     normalize=False,
+     normalize=True,
  ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
-     """Hessian-vector product using forward finite difference formula.
+     """Returns ``(loss_approx, Hx)``.
  
-     Gradient at initial point can be specified in the `g_0` argument.
+     Gradient at initial point can be specified in the ``g_0`` argument.
  
-     Please note that this will clear :code:`grad` attributes in params.
+     Please note that this will clear ``grad`` attributes in params.
  
      Example:
      ```python
@@ -492,16 +459,16 @@ def hvp_fd_forward(
      """
  
      params = list(params)
-     vec = list(vec)
+     x = list(x)
      loss = None
  
      vec_norm = None
      if normalize:
-         vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in vec])) # pylint:disable=not-callable
+         vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in x])) # pylint:disable=not-callable
          if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
-         vec = torch._foreach_div(vec, vec_norm)
+         x = torch._foreach_div(x, vec_norm)
  
-     vec_h = torch._foreach_mul(vec, h)
+     xh = torch._foreach_mul(x, h)
  
      if g_0 is None:
          with torch.enable_grad(): loss = closure()
@@ -509,18 +476,75 @@
      else:
          g_0 = list(g_0)
  
-     torch._foreach_add_(params, vec_h)
+     torch._foreach_add_(params, xh)
      with torch.enable_grad():
          l = closure()
          if loss is None: loss = l
      g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
  
-     torch._foreach_sub_(params, vec_h)
+     torch._foreach_sub_(params, xh)
      for p in params: p.grad = None
  
-     hvp_ = g_plus
-     torch._foreach_sub_(hvp_, g_0)
-     torch._foreach_div_(hvp_, h)
+     hx = g_plus
+     torch._foreach_sub_(hx, g_0)
+     torch._foreach_div_(hx, h)
+
+     if normalize: torch._foreach_mul_(hx, vec_norm)
+     return loss, hx
+
+ @torch.no_grad
+ def hessian_fd(fn, params: Sequence[torch.Tensor], eps: float = 1e-4, full: bool = True):
+     """returns ``f(x), g(x), H(x)``, where ``g(x)`` is a tensor list.
+
+     Number of evals for full is: 4n^2 - 2n + 1
+
+     Number of evals for upper is: 2n^2 + 1.
+     """
+     params = TensorList(params)
+     p_0 = params.clone()
+     n = sum(t.numel() for t in params)
+     device = params[0].device; dtype = params[0].dtype
+     fx = fn()
+     g = params.zeros_like()
+     H = torch.zeros((n, n), device=device, dtype=dtype)
+
+     for i in range(n):
+         for j in (range(n) if full else range(i, n)):
+             if i == j:
+                 params.flat_set_lambda_(i, lambda x: x + eps)
+                 f_plus = fn()
+
+                 params.flat_set_lambda_(i, lambda x: x - 2 * eps)
+                 f_minus = fn()
+
+                 # params.flat_set_lambda_(i, lambda x: x + eps)
+                 g.flat_set_(i, (f_plus - f_minus) / (2*eps))
+                 H[i, i] = (f_plus - 2 * fx + f_minus) / (eps ** 2)
+
+             else:
+                 params.flat_set_lambda_(i, lambda x: x + eps)
+                 params.flat_set_lambda_(j, lambda x: x + eps)
+                 f_pp = fn()
+
+                 params.flat_set_lambda_(i, lambda x: x - 2 * eps)
+                 f_np = fn()
+
+                 params.flat_set_lambda_(j, lambda x: x - 2 * eps)
+                 f_nn = fn()
+
+                 params.flat_set_lambda_(i, lambda x: x + 2 * eps)
+                 f_pn = fn()
+
+                 # params.flat_set_lambda_(i, lambda x: x - eps)
+                 # params.flat_set_lambda_(j, lambda x: x + eps)
+
+                 H[i, j] = (f_pp - f_np - f_pn + f_nn) / (4 * eps ** 2)
+                 if not full: H[j, i] = H[i, j]
+
+             params.copy_(p_0) # otherwise inaccuracy builds up
+
+     if full:
+         H = H + H.T
+         H /= 2
  
-     if normalize: torch._foreach_mul_(hvp_, vec_norm)
-     return loss, hvp_
+     return fx, g, H
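
The new `hessian_fd` helper approximates the full Hessian with central differences over the flattened parameter vector, restoring the parameters after every entry. A minimal usage sketch (the quadratic closure and the `torchzero.utils.derivatives` import path are illustrative assumptions, not taken from the package docs):

```python
import torch
from torchzero.utils.derivatives import hessian_fd  # assumed import path

# two parameter tensors, flattened length n = 3
a = torch.tensor([1.0, -2.0])
b = torch.tensor([0.5])

def f():
    # quadratic with Hessian [[2, 0, 1], [0, 6, 0], [1, 0, 4]] w.r.t. (a0, a1, b0)
    return a[0]**2 + 3 * a[1]**2 + 2 * b[0]**2 + a[0] * b[0]

fx, g, H = hessian_fd(f, [a, b], eps=1e-4, full=True)
print(H)  # approximately the matrix above; g is a TensorList shaped like [a, b]
```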
torchzero/utils/optimizer.py

@@ -64,22 +64,15 @@ def get_group_vals(param_groups: Iterable[Mapping[str, Any]],
          values[i].extend(group_value for _ in range(num_params))
      return values
  
- _InitLiterals = Literal['param', 'grad']
- Init = _InitLiterals | Any | list[_InitLiterals | Any] | tuple[_InitLiterals | Any]
+ Init = Any
  
- def _make_initial_state_value(param: torch.Tensor, init: Init, i: int | None):
-     if callable(init): return init(param)
+ def _make_initial_state_value(tensor: torch.Tensor, init: Init, i: int | None):
+     if callable(init): return init(tensor)
      if isinstance(init, torch.Tensor): return init.detach().clone()
  
-     if isinstance(init, str):
-         if init in ('param','params'): return param.detach().clone()
-         if init in ('grad', 'grads'):
-             if param.grad is None: raise RuntimeError('init is set to "grad, but param.grad is None"')
-             return param.grad.detach().clone()
-
      if isinstance(init, (list,tuple)):
          if i is None: raise RuntimeError(f'init is per-parameter ({type(init)}) but parameter index i is None')
-         return _make_initial_state_value(param, init[i], None)
+         return _make_initial_state_value(tensor, init[i], None)
  
      return init
  
@@ -133,72 +126,6 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
      return values
  
  
- class Optimizer(torch.optim.Optimizer, ABC):
-     """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.
-
-     Args:
-         params (iterable): an iterable of :class:`torch.Tensor` s or
-             :class:`dict` s. Specifies what Tensors should be optimized.
-         defaults (dict | None): a dict containing default values of optimization
-             options (used when a parameter group doesn't specify them).
-     """
-     def __init__(self, params, defaults: dict[str, Any] | None = None, **_defaults):
-         if defaults is None: defaults = {}
-         defaults.update(_defaults)
-
-         super().__init__(params, defaults)
-         self.global_state = self.state[self.param_groups[0]['params'][0]]
-         """state of 1st parameter, can be used as global state which is how L-BFGS uses it in pytorch, and there is some kind of good reason to do it like that"""
-
-     def get_params(self, mode: ParamFilter = 'requires_grad', cls: type[ListLike] = TensorList) -> ListLike:
-         return get_params(self.param_groups, mode, cls)
-
-     @overload
-     def group_vals(self, key: str, *,
-                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike: ...
-     @overload
-     def group_vals(self, key: list[str] | tuple[str,...], *,
-                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
-     @overload
-     def group_vals(self, key: str, key2: str, *keys: str,
-                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
-
-     def group_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
-                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike | list[ListLike]:
-         return get_group_vals(self.param_groups, key, key2, *keys, mode = mode, cls = cls) # pyright:ignore[reportArgumentType]
-
-
-     @overload
-     def state_vals(self, key: str, *,
-                    init: Init = torch.zeros_like,
-                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                    cls: type[ListLike] = TensorList) -> ListLike: ...
-     @overload
-     def state_vals(self, key: list[str] | tuple[str,...], *,
-                    init: Init | Sequence[Init] = torch.zeros_like,
-                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                    cls: type[ListLike] = TensorList) -> list[ListLike]: ...
-     @overload
-     def state_vals(self, key: str, key2: str, *keys: str,
-                    init: Init | Sequence[Init] = torch.zeros_like,
-                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                    cls: type[ListLike] = TensorList) -> list[ListLike]: ...
-
-     def state_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
-                    init: Init | Sequence[Init] = torch.zeros_like,
-                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                    cls: type[ListLike] = TensorList) -> ListLike | list[ListLike]:
-
-         if isinstance(mode, (list,tuple)): params = mode
-         else: params = self.get_params(mode)
-
-         return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]
-
-
-     # shut up pylance
-     @abstractmethod
-     def step(self, closure) -> Any: ... # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-
  def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
      if set_to_none:
          for p in params:
torchzero/utils/python_tools.py

@@ -1,3 +1,4 @@
+ import importlib
  import functools
  import operator
  from typing import Any, TypeVar, overload
@@ -40,6 +41,11 @@ def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable
          return any(i!=y for i in x)
      return any(i!=j for i,j in zip(x,y))
  
+ def generic_is_none(x: Any | Iterable[Any]):
+     """returns True if x is None or iterable with all elements set to None"""
+     if x is None: return True
+     if isinstance(x, Iterable): return all(i is None for i in x)
+     return False
  
  def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
      """If `other` is list/tuple, applies `fn` to self zipped with `other`.
@@ -68,3 +74,28 @@ def safe_dict_update_(d1_:dict, d2:dict):
      if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
      d1_.update(d2)
  
+ # lazy loader from https://stackoverflow.com/a/78312674/15673832
+ class LazyLoader:
+     'thin shell class to wrap modules. load real module on first access and pass thru'
+
+     def __init__(self, modname):
+         self._modname = modname
+         self._mod = None
+
+     def __getattr__(self, attr):
+         'import module on first attribute access'
+
+         try:
+             return getattr(self._mod, attr)
+
+         except Exception as e :
+             if self._mod is None :
+                 # module is unset, load it
+                 self._mod = importlib.import_module (self._modname)
+             else :
+                 # module is set, got different exception from getattr (). reraise it
+                 raise e
+
+         # retry getattr if module was just loaded for first time
+         # call this outside exception handler in case it raises new exception
+         return getattr (self._mod, attr)
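
`LazyLoader` defers `importlib.import_module` until the first attribute access, which is how the new `thoad_tools.py` shown below keeps `thoad` an optional dependency. A small illustration (using `math` purely as a stand-in target module):

```python
from torchzero.utils.python_tools import LazyLoader  # assumed import path

lazy_math = LazyLoader("math")   # nothing imported yet, _mod is still None
print(lazy_math.sqrt(2.0))       # first attribute access triggers the real import
print(lazy_math.pi)              # later accesses pass straight through to the module
```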
torchzero/utils/tensorlist.py

@@ -22,7 +22,6 @@ from typing_extensions import Self, TypeAlias, Unpack
  
  from .metrics import Metrics, evaluate_metric, calculate_metric_list
  from .numberlist import NumberList, as_numberlist, maybe_numberlist
- from .ops import where_
  from .python_tools import generic_ne, zipmap
  
  _Scalar = int | float | bool | complex
@@ -346,6 +345,10 @@ class TensorList(list[torch.Tensor | Any]):
      def global_all(self): return builtins.all(self.all())
      def global_numel(self) -> int: return builtins.sum(self.numel())
  
+     def global_allclose(self, other: _TensorSeq, rtol: float = 0.00001, atol: float = 1e-8, equal_nan: bool = False) -> bool:
+         bools = self.zipmap_args(torch.allclose, other, rtol, atol, equal_nan)
+         return all(bools)
+
      def empty_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.empty_like(i, **kwargs) for i in self)
      def zeros_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.zeros_like(i, **kwargs) for i in self)
      def ones_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.ones_like(i, **kwargs) for i in self)
@@ -509,7 +512,6 @@ class TensorList(list[torch.Tensor | Any]):
          torch._foreach_mul_(self, other)
          return self
  
-     # TODO: benchmark
      def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
          if generic_ne(other, 1):
              return self * other
@@ -536,6 +538,13 @@ class TensorList(list[torch.Tensor | Any]):
          torch._foreach_pow_(self, exponent)
          return self
  
+     def lazy_pow(self, other: int | float | list[int | float] | tuple[int | float]):
+         if generic_ne(other, 1): return self.pow(other)
+         return self
+     def lazy_pow_(self, other: int | float | list[int | float] | tuple[int | float]):
+         if generic_ne(other, 1): return self.pow_(other)
+         return self
+
      def rpow(self, input: _Scalar | _TensorSeq): return self.__class__(torch._foreach_pow(input, self))
      def rpow_(self, input: _TensorSeq):
          torch._foreach_pow_(input, self)
@@ -984,9 +993,6 @@ class TensorList(list[torch.Tensor | Any]):
      def where(self, condition: "torch.Tensor | _TensorSeq", other: _STOrSTSeq):
          """self where condition is true other otherwise"""
          return self.zipmap_args(_MethodCallerWithArgs('where'), condition, other)
-     def where_(self, condition: "torch.Tensor | _TensorSeq", other: "torch.Tensor | _TensorSeq"):
-         """self where condition is true other otherwise"""
-         return self.zipmap_args_inplace_(where_, condition, other)
  
      def masked_fill(self, mask: "torch.Tensor | _TensorSeq", fill_value: "_Scalar | _ScalarSeq"):
          """Same as tensor[mask] = value (not in-place), where value must be scalar/scalars"""
torchzero/utils/thoad_tools.py (new file)

@@ -0,0 +1,68 @@
+ import itertools
+ from collections.abc import Callable
+ from importlib.util import find_spec
+ from typing import TYPE_CHECKING, cast
+
+ import torch
+
+ from .python_tools import LazyLoader
+
+ lazy_thoad = LazyLoader("thoad")
+ if TYPE_CHECKING:
+     import thoad
+     lazy_thoad = cast(thoad, lazy_thoad)
+
+ def thoad_single_tensor(
+     ctrl: "thoad.Controller",
+     params: list[torch.Tensor],
+     order: int
+ ) -> torch.Tensor:
+     """treats params as if they were concatenated into a vector."""
+
+     if not all(p.requires_grad for p in params):
+         raise ValueError("All parameters must have requires_grad=True")
+
+     if order < 1:
+         raise ValueError("Order must be at least 1")
+
+     # we need parameter sizes and total size N
+     # final tensor is (N, N, ..., N) with `order` dimensions.
+     param_numels = [p.numel() for p in params]
+     total_params = sum(param_numels)
+
+     final_shape = (total_params,) * order
+     p = params[0]
+     T = torch.zeros(final_shape, device=p.device, dtype=p.dtype)
+
+     # start/end indices for each parameter in the flattened vector.
+     offsets = torch.cumsum(torch.tensor([0] + param_numels), dim=0)
+
+     # for order=2 this iterates through (p0,p0), (p0,p1), (p1,p0), (p1,p1), etc.
+     param_indices = range(len(params))
+     for block_indices in itertools.product(param_indices, repeat=order):
+
+         block_params = tuple(params[i] for i in block_indices)
+         block_tensor, _ = ctrl.fetch_hgrad(variables=block_params) # (1, *p1.shape, *p2.shape, ...).
+         block_tensor = block_tensor.squeeze(0) # (*p1.shape, *p2.shape, ...)
+
+         # convert (*p1.shape, *p2.shape) to (p1.numel(), p2.numel())
+         block_flat_shape = tuple(param_numels[i] for i in block_indices)
+         block_tensor_flat = block_tensor.reshape(block_flat_shape)
+
+         # place the flattened block into T
+         slicing = tuple(
+             slice(offsets[i], offsets[i+1]) for i in block_indices
+         )
+         T[slicing] = block_tensor_flat
+
+     ctrl.clear()
+     return T
+
+ def thoad_derivatives(
+     ctrl: "thoad.Controller",
+     params: list[torch.Tensor],
+     order: int,
+ ):
+     """returns all derivatives up to ``order`` in ascending order, all as single tensors
+     as if parameters were concatenated to a vector"""
+     return [thoad_single_tensor(ctrl, params, o) for o in range(1, order+1)]
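
The indexing scheme in `thoad_single_tensor` can be checked in isolation: cumulative offsets map each parameter onto a slice of the flattened vector, and `itertools.product` enumerates every block of the order-`k` tensor. The sketch below reuses that logic with dummy blocks in place of `ctrl.fetch_hgrad` output, so it runs without `thoad` (all values are hypothetical, for illustration only):

```python
import itertools
import torch

param_numels = [2, 3]          # pretend parameters with 2 and 3 elements, N = 5
order = 2
N = sum(param_numels)
offsets = torch.cumsum(torch.tensor([0] + param_numels), dim=0)

T = torch.zeros((N,) * order)
for block_indices in itertools.product(range(len(param_numels)), repeat=order):
    # stand-in for the flattened derivative block of this parameter pair
    shape = tuple(param_numels[i] for i in block_indices)
    block = torch.full(shape, float(sum(block_indices)))
    slicing = tuple(slice(offsets[i], offsets[i + 1]) for i in block_indices)
    T[slicing] = block

print(T)  # 5x5 tensor assembled from the 2x2, 2x3, 3x2 and 3x3 blocks
```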
{torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.14
+ Version: 0.4.0
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  Project-URL: Homepage, https://github.com/inikishev/torchzero