torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/newton.py
@@ -1,147 +1,119 @@
- import warnings
  from collections.abc import Callable
- from functools import partial
- from typing import Literal
+ from typing import Any

  import torch

- from ...core import Chainable, Module, apply_transform, Var
- from ...utils import TensorList, vec_to_tensors
- from ...utils.derivatives import (
-     flatten_jacobian,
-     hessian_mat,
-     hvp,
-     hvp_fd_central,
-     hvp_fd_forward,
-     jacobian_and_hessian_wrt,
- )
- from ...utils.linalg.linear_operator import DenseWithInverse, Dense
-
- def _lu_solve(H: torch.Tensor, g: torch.Tensor):
+ from ...core import Chainable, Transform, Objective, HessianMethod
+ from ...utils import vec_to_tensors_
+ from ...linalg.linear_operator import Dense, DenseWithInverse, Eigendecomposition
+ from ...linalg import torch_linalg
+
+ def _try_lu_solve(H: torch.Tensor, g: torch.Tensor):
      try:
-         x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+         x, info = torch_linalg.solve_ex(H, g, retry_float64=True)
          if info == 0: return x
          return None
      except RuntimeError:
          return None

- def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
-     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
+ def _try_cholesky_solve(H: torch.Tensor, g: torch.Tensor):
+     L, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
      if info == 0:
-         g.unsqueeze_(1)
-         return torch.cholesky_solve(g, x)
+         return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
      return None

  def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
      return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable

- def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
-     try:
-         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
-         if tfm is not None: L = tfm(L)
-         if search_negative and L[0] < 0:
-             neg_mask = L < 0
-             Q_neg = Q[:, neg_mask] * L[neg_mask]
-             return (Q_neg * (g @ Q_neg).sign()).mean(1)
-
-         return Q @ ((Q.mH @ g) / L)
-
-     except torch.linalg.LinAlgError:
-         return None
-
-
- def _get_loss_grad_and_hessian(var: Var, hessian_method:str, vectorize:bool):
-     """returns (loss, g_list, H). Also sets var.loss and var.grad.
-     If hessian_method isn't 'autograd', loss is not set and returned as None"""
-     closure = var.closure
-     if closure is None:
-         raise RuntimeError("Second order methods requires a closure to be provided to the `step` method.")
-
-     params = var.params
-
-     # ------------------------ calculate grad and hessian ------------------------ #
-     loss = None
-     if hessian_method == 'autograd':
-         with torch.enable_grad():
-             loss = var.loss = var.loss_approx = closure(False)
-             g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-             g_list = [t[0] for t in g_list] # remove leading dim from loss
-             var.grad = g_list
-             H = flatten_jacobian(H_list)
-
-     elif hessian_method in ('func', 'autograd.functional'):
-         strat = 'forward-mode' if vectorize else 'reverse-mode'
-         with torch.enable_grad():
-             g_list = var.get_grad(retain_graph=True)
-             H = hessian_mat(partial(closure, backward=False), params,
-                             method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-     else:
-         raise ValueError(hessian_method)
-
-     return loss, g_list, H
-
- def _newton_step(var: Var, H: torch.Tensor, damping:float, inner: Module | None, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None) -> torch.Tensor:
-     """returns the update tensor, then do vec_to_tensor(update, params)"""
-     params = var.params
-
+ def _newton_update_state_(
+     state: dict,
+     H: torch.Tensor,
+     damping: float,
+     eigval_fn: Callable | None,
+     precompute_inverse: bool,
+     use_lstsq: bool,
+ ):
+     """used in most hessian-based modules"""
+     # add damping
      if damping != 0:
-         H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
-
-     # -------------------------------- inner step -------------------------------- #
-     update = var.get_update()
-     if inner is not None:
-         update = apply_transform(inner, update, params=params, grads=var.grad, loss=var.loss, var=var)
-
-     g = torch.cat([t.ravel() for t in update])
-     if g_proj is not None: g = g_proj(g)
-
-     # ----------------------------------- solve ---------------------------------- #
-     update = None
-
-     if H_tfm is not None:
-         ret = H_tfm(H, g)
-
-         if isinstance(ret, torch.Tensor):
-             update = ret
-
-         else: # returns (H, is_inv)
-             H, is_inv = ret
-             if is_inv: update = H @ g
+         reg = torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(damping)
+         H += reg

+     # if eigval_fn is given, we don't need H or H_inv, we store factors
      if eigval_fn is not None:
-         update = _eigh_solve(H, g, eigval_fn, search_negative=False)
-
-     if update is None and use_lstsq: update = _least_squares_solve(H, g)
-     if update is None: update = _cholesky_solve(H, g)
-     if update is None: update = _lu_solve(H, g)
-     if update is None: update = _least_squares_solve(H, g)
-
-     return update
-
- def _get_H(H: torch.Tensor, eigval_fn):
-     if eigval_fn is not None:
-         try:
-             L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
-             L: torch.Tensor = eigval_fn(L)
-             H = Q @ L.diag_embed() @ Q.mH
-             H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
-             return DenseWithInverse(H, H_inv)
-
-         except torch.linalg.LinAlgError:
-             pass
-
-     return Dense(H)
+         L, Q = torch_linalg.eigh(H, retry_float64=True)
+         L = eigval_fn(L)
+         state["L"] = L
+         state["Q"] = Q
+         return
+
+     # pre-compute inverse if requested
+     # store H to as it is needed for trust regions
+     state["H"] = H
+     if precompute_inverse:
+         if use_lstsq:
+             H_inv = torch.linalg.pinv(H) # pylint:disable=not-callable
+         else:
+             H_inv, _ = torch_linalg.inv_ex(H)
+         state["H_inv"] = H_inv
+
+
+ def _newton_solve(
+     b: torch.Tensor,
+     state: dict[str, torch.Tensor | Any],
+     use_lstsq: bool = False,
+ ):
+     """
+     used in most hessian-based modules. state is from ``_newton_update_state_``, in it:

- class Newton(Module):
-     """Exact newton's method via autograd.
+     H (torch.Tensor): hessian
+     H_inv (torch.Tensor | None): hessian inverse
+     L (torch.Tensor | None): eigenvalues (transformed)
+     Q (torch.Tensor | None): eigenvectors
+     """
+     # use eig if provided
+     if "L" in state:
+         Q = state["Q"]; L = state["L"]
+         assert Q is not None
+         return Q @ ((Q.mH @ b) / L)
+
+     # use inverse if cached
+     if "H_inv" in state:
+         return state["H_inv"] @ b
+
+     # use hessian
+     H = state["H"]
+     if use_lstsq: return _least_squares_solve(H, b)
+
+     dir = None
+     if dir is None: dir = _try_cholesky_solve(H, b)
+     if dir is None: dir = _try_lu_solve(H, b)
+     if dir is None: dir = _least_squares_solve(H, b)
+     return dir
+
+ def _newton_get_H(state: dict[str, torch.Tensor | Any]):
+     """used in most hessian-based modules. state is from ``_newton_update_state_``"""
+     if "H_inv" in state:
+         return DenseWithInverse(state["H"], state["H_inv"])
+
+     if "L" in state:
+         # Eigendecomposition has sligthly different solve_plus_diag
+         # I am pretty sure it should be very close and it uses no solves
+         # best way to test is to try cubic regularization with this
+         return Eigendecomposition(state["L"], state["Q"], use_nystrom=False)
+
+     return Dense(state["H"])
+
+ class Newton(Transform):
+     """Exact Newton's method via autograd.

      Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
      The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.
+
      ``g`` can be output of another module, if it is specifed in ``inner`` argument.

      Note:
-         In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+         In most cases Newton should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

      Note:
          This module requires the a closure passed to the optimizer step,
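
Note on the refactor above: the solve path is now split into `_newton_update_state_` (caches either the transformed eigenfactors, a precomputed inverse, or the damped Hessian) and `_newton_solve` (consumes whatever was cached, falling back from Cholesky to LU to least squares). Below is a minimal standalone sketch of that fallback order using only public `torch.linalg` calls rather than the library's `torch_linalg` wrappers; the function name is illustrative, not part of the package.

```py
import torch

def solve_with_fallback(H: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # 1. Cholesky: cheapest, but only succeeds when H is positive definite
    L, info = torch.linalg.cholesky_ex(H)
    if info == 0:
        return torch.cholesky_solve(b.unsqueeze(-1), L).squeeze(-1)
    # 2. LU: works for any invertible H
    x, info = torch.linalg.solve_ex(H, b)
    if info == 0:
        return x
    # 3. least squares: last resort, also handles singular H
    return torch.linalg.lstsq(H, b.unsqueeze(-1)).solution.squeeze(-1)

H = torch.tensor([[2.0, 0.0], [0.0, -1.0]])   # indefinite, so Cholesky fails and LU is used
b = torch.tensor([1.0, 1.0])
print(solve_with_fallback(H, b))              # tensor([ 0.5000, -1.0000])
```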
@@ -149,38 +121,43 @@ class Newton(Module):
          The closure must accept a ``backward`` argument (refer to documentation).

      Args:
-         damping (float, optional): tikhonov regularizer value. Set this to 0 when using trust region. Defaults to 0.
-         search_negative (bool, Optional):
-             if True, whenever a negative eigenvalue is detected,
-             search direction is proposed along weighted sum of eigenvectors corresponding to negative eigenvalues.
-         use_lstsq (bool, Optional):
-             if True, least squares will be used to solve the linear system, this may generate reasonable directions
-             when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
-             If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
-             argument will be ignored.
-         hessian_method (str):
-             how to calculate hessian. Defaults to "autograd".
-         vectorize (bool, optional):
-             whether to enable vectorized hessian. Defaults to True.
-         H_tfm (Callable | None, optional):
-             optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
-
-             must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
-             which must be True if transform inverted the hessian and False otherwise.
-
-             Or it returns a single tensor which is used as the update.
-
-             Defaults to None.
+         damping (float, optional): tikhonov regularizer value. Defaults to 0.
          eigval_fn (Callable | None, optional):
-             optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
+             function to apply to eigenvalues, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
              If this is specified, eigendecomposition will be used to invert the hessian.
+         update_freq (int, optional):
+             updates hessian every ``update_freq`` steps.
+         precompute_inverse (bool, optional):
+             if ``True``, whenever hessian is computed, also computes the inverse. This is more efficient
+             when ``update_freq`` is large. If ``None``, this is ``True`` if ``update_freq >= 10``.
+         use_lstsq (bool, Optional):
+             if True, least squares will be used to solve the linear system, this can prevent it from exploding
+             when hessian is indefinite. If False, tries cholesky, if it fails tries LU, and then least squares.
+             If ``eigval_fn`` is specified, eigendecomposition is always used and this argument is ignored.
+         hessian_method (str):
+             Determines how hessian is computed.
+
+             - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
+             - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
+             - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
+             - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_fwdrev"`` but uses more memory (``"batched_autograd"`` seems to be faster)
+             - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
+             - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+             - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+             - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
+             - ``"thoad"`` - uses ``thoad`` library, can be significantly faster than pytorch but limited operator coverage.
+
+             Defaults to ``"batched_autograd"``.
+         h (float, optional):
+             finite difference step size if hessian is compute via finite-difference.
          inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.

      # See also

-     * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products,
+     * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products.
        useful for large scale problems as it doesn't form the full hessian.
      * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
+     * ``tz.m.ImprovedNewton``: Newton with additional rank one correction to the hessian, can be faster than Newton.
      * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
      * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.

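The ``"gfd_forward"`` and ``"gfd_central"`` options documented in the new ``hessian_method`` argument describe standard finite-difference Hessian-vector products built from gradients. A rough sketch of the two formulas as described there (illustrative only, not the library's implementation; the helper names are made up):

```py
import torch

def grad(f, x):
    # fresh leaf tensor so we can differentiate through f at an arbitrary point
    x = x.detach().requires_grad_(True)
    return torch.autograd.grad(f(x), x)[0]

def hvp_gfd_forward(f, x, v, h=1e-3):
    # forward formula: one extra gradient evaluation per product
    return (grad(f, x + h * v) - grad(f, x)) / h

def hvp_gfd_central(f, x, v, h=1e-3):
    # central formula: two gradient evaluations per product, more accurate
    return (grad(f, x + h * v) - grad(f, x - h * v)) / (2 * h)

f = lambda x: (x ** 4).sum()          # toy objective, exact hessian is diag(12 x²)
x = torch.tensor([1.0, 2.0])
v = torch.tensor([1.0, 0.0])
print(hvp_gfd_forward(f, x, v), hvp_gfd_central(f, x, v))  # both close to [12., 0.]
```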
@@ -189,57 +166,48 @@ class Newton(Module):
      ## Implementation details

      ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
-     The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
-     Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
+     The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares. Least squares can be forced by setting ``use_lstsq=True``.

      Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
-     ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive,
-     but not by much
+     ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive but not by much.

      ## Handling non-convexity

      Standard Newton's method does not handle non-convexity well without some modifications.
      This is because it jumps to the stationary point, which may be the maxima of the quadratic approximation.

-     The first modification to handle non-convexity is to modify the eignevalues to be positive,
+     A modification to handle non-convexity is to modify the eignevalues to be positive,
      for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.

-     Second modification is ``search_negative=True``, which will search along a negative curvature direction if one is detected.
-     This also requires an eigendecomposition.
-
-     The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
-     but that may be significantly less efficient.
-
      # Examples:

      Newton's method with backtracking line search

      ```py
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
          tz.m.Newton(),
          tz.m.Backtracking()
      )
      ```

-     Newton preconditioning applied to momentum
+     Newton's method for non-convex optimization.

      ```py
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
-         tz.m.Newton(inner=tz.m.EMA(0.9)),
-         tz.m.LR(0.1)
+         tz.m.Newton(eigval_fn = lambda L: L.abs().clip(min=1e-4)),
+         tz.m.Backtracking()
      )
      ```

-     Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient,
-     but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+     Newton preconditioning applied to momentum

      ```py
-     opt = tz.Modular(
+     opt = tz.Optimizer(
          model.parameters(),
-         tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
-         tz.m.Backtracking()
+         tz.m.Newton(inner=tz.m.EMA(0.9)),
+         tz.m.LR(0.1)
      )
      ```

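For reference, the ``eigval_fn`` route shown in the second docstring example corresponds to the ``Q @ ((Q.mH @ b) / L)`` branch of ``_newton_solve`` earlier in this diff. A tiny standalone sketch of what clipping eigenvalues does to the Newton step on an indefinite Hessian (not library code, toy numbers only):

```py
import torch

# toy indefinite "hessian" and gradient
H = torch.tensor([[2.0,  0.0],
                  [0.0, -0.5]])        # one negative eigenvalue
g = torch.tensor([1.0, 1.0])

eigval_fn = lambda L: L.abs().clip(min=1e-4)

L, Q = torch.linalg.eigh(H)            # H = Q diag(L) Qᵀ
step = Q @ ((Q.mH @ g) / eigval_fn(L)) # same form as the "L"/"Q" branch of _newton_solve
print(step)  # tensor([0.5000, 2.0000]); plain H⁻¹g would be [0.5, -2.0], an ascent direction
```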
@@ -247,47 +215,48 @@ class Newton(Module):
      def __init__(
          self,
          damping: float = 0,
-         use_lstsq: bool = False,
-         update_freq: int = 1,
-         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-         vectorize: bool = True,
-         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
          eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+         update_freq: int = 1,
+         precompute_inverse: bool | None = None,
+         use_lstsq: bool = False,
+         hessian_method: HessianMethod = "batched_autograd",
+         h: float = 1e-3,
          inner: Chainable | None = None,
      ):
-         defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
-         super().__init__(defaults)
-
-         if inner is not None:
-             self.set_child('inner', inner)
+         defaults = locals().copy()
+         del defaults['self'], defaults['update_freq'], defaults["inner"]
+         super().__init__(defaults, update_freq=update_freq, inner=inner)

      @torch.no_grad
-     def update(self, var):
-         step = self.global_state.get('step', 0)
-         self.global_state['step'] = step + 1
-
-         if step % self.defaults['update_freq'] == 0:
-             loss, g_list, self.global_state['H'] = _get_loss_grad_and_hessian(
-                 var, self.defaults['hessian_method'], self.defaults['vectorize']
-             )
+     def update_states(self, objective, states, settings):
+         fs = settings[0]
+
+         precompute_inverse = fs["precompute_inverse"]
+         if precompute_inverse is None:
+             precompute_inverse = fs["__update_freq"] >= 10
+
+         __, _, H = objective.hessian(hessian_method=fs["hessian_method"], h=fs["h"], at_x0=True)
+
+         _newton_update_state_(
+             state = self.global_state,
+             H=H,
+             damping = fs["damping"],
+             eigval_fn = fs["eigval_fn"],
+             precompute_inverse = precompute_inverse,
+             use_lstsq = fs["use_lstsq"]
+         )

      @torch.no_grad
-     def apply(self, var):
-         params = var.params
-         update = _newton_step(
-             var=var,
-             H = self.global_state["H"],
-             damping=self.defaults["damping"],
-             inner=self.children.get("inner", None),
-             H_tfm=self.defaults["H_tfm"],
-             eigval_fn=self.defaults["eigval_fn"],
-             use_lstsq=self.defaults["use_lstsq"],
-         )
+     def apply_states(self, objective, states, settings):
+         updates = objective.get_updates()
+         fs = settings[0]

-         var.update = vec_to_tensors(update, params)
+         b = torch.cat([t.ravel() for t in updates])
+         sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])

-         return var
+         vec_to_tensors_(sol, updates)
+         return objective

-     def get_H(self,var=...):
-         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
+     def get_H(self,objective=...):
+         return _newton_get_H(self.global_state)
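
Taken together, the constructor changes above amount to roughly the following migration for user code. This is a sketch inferred from this hunk alone: the argument values are illustrative, and the toy `model` is only there to make it runnable.

```py
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# 0.3.15 (per the removed lines): tz.Modular wrapper, "autograd"/vectorize/H_tfm style arguments
# opt = tz.Modular(model.parameters(), tz.m.Newton(hessian_method="autograd", vectorize=True), tz.m.Backtracking())

# 0.4.1 (per the added lines): tz.Optimizer wrapper, eigval_fn / update_freq / precompute_inverse knobs
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Newton(
        eigval_fn=lambda L: L.abs().clip(min=1e-4),  # keep the step well-defined on indefinite hessians
        update_freq=10,                              # recompute the hessian every 10 steps
        precompute_inverse=True,                     # amortize the inverse across those steps
    ),
    tz.m.Backtracking(),
)
```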