torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -1,153 +1,187 @@
- from collections.abc import Callable
- from typing import Literal, overload
  import warnings
- import torch
+ from typing import Literal

- from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
- from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+ import torch

- from ...core import Chainable, apply_transform, Module
- from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg
+ from ...core import Chainable, Transform, HVPMethod
+ from ...utils import TensorList, vec_to_tensors
+ from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod
+ from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity

- class NystromSketchAndSolve(Module):
+ class NystromSketchAndSolve(Transform):
  """Newton's method with a Nyström sketch-and-solve solver.

- .. note::
- This module requires the a closure passed to the optimizer step,
- as it needs to re-evaluate the loss and gradients for calculating HVPs.
- The closure must accept a ``backward`` argument (refer to documentation).
-
- .. note::
- In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+ Notes:
+ - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

- .. note::
- If this is unstable, increase the :code:`reg` parameter and tune the rank.
+ - In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

- .. note:
- :code:`tz.m.NystromPCG` usually outperforms this.
+ - If this is unstable, increase the ``reg`` parameter and tune the rank.

  Args:
  rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
- reg (float, optional): regularization parameter. Defaults to 1e-3.
+ reg (float | None, optional):
+ scale of identity matrix added to hessian. Note that if this is specified, nystrom sketch-and-solve
+ is used to compute ``(Q diag(L) Q.T + reg*I)x = b``. It is very unstable when ``reg`` is small,
+ i.e. smaller than 1e-4. If this is None, ``(Q diag(L) Q.T)x = b`` is computed by simply taking
+ reciprocal of eigenvalues. Defaults to 1e-3.
+ eigv_tol (float, optional):
+ all eigenvalues smaller than largest eigenvalue times ``eigv_tol`` are removed. Defaults to None.
+ truncate (int | None, optional):
+ keeps top ``truncate`` eigenvalues. Defaults to None.
+ damping (float, optional): scalar added to eigenvalues. Defaults to 0.
+ rdamping (float, optional): scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.
+ update_freq (int, optional): frequency of updating preconditioner. Defaults to 1.
  hvp_method (str, optional):
- Determines how Hessian-vector products are evaluated.
-
- - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
- This requires creating a graph for the gradient.
- - ``"forward"``: Use a forward finite difference formula to
- approximate the HVP. This requires one extra gradient evaluation.
- - ``"central"``: Use a central finite difference formula for a
- more accurate HVP approximation. This requires two extra
- gradient evaluations.
- Defaults to "autograd".
- h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+ Determines how Hessian-vector products are computed.
+
+ - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+ - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+ - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+ - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+ Defaults to ``"autograd"``.
+ h (float, optional):
+ The step size for finite difference if ``hvp_method`` is
+ ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
  inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
  seed (int | None, optional): seed for random generator. Defaults to None.

+
  Examples:
- NystromSketchAndSolve with backtracking line search
+ NystromSketchAndSolve with backtracking line search

- .. code-block:: python
+ ```py
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.NystromSketchAndSolve(100),
+ tz.m.Backtracking()
+ )
+ ```

- opt = tz.Modular(
- model.parameters(),
- tz.m.NystromSketchAndSolve(10),
- tz.m.Backtracking()
- )
+ Trust region NystromSketchAndSolve
+
+ ```py
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
+ )
+ ```
+
+ References:
+ - [Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.](https://arxiv.org/pdf/2211.08597)
+ - [Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752](https://arxiv.org/abs/2110.02820)

- Reference:
- Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
  """
  def __init__(
  self,
  rank: int,
- reg: float = 1e-3,
- hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+ reg: float | None = 1e-2,
+ eigv_tol: float = 0,
+ truncate: int | None = None,
+ damping: float = 0,
+ rdamping: float = 0,
+ update_freq: int = 1,
+ orthogonalize_method: OrthogonalizeMethod = 'qr',
+ hvp_method: HVPMethod = "batched_autograd",
  h: float = 1e-3,
  inner: Chainable | None = None,
  seed: int | None = None,
  ):
- defaults = dict(rank=rank, reg=reg, hvp_method=hvp_method, h=h, seed=seed)
- super().__init__(defaults,)
-
- if inner is not None:
- self.set_child('inner', inner)
+ defaults = locals().copy()
+ del defaults['self'], defaults['inner'], defaults["update_freq"]
+ super().__init__(defaults, update_freq=update_freq, inner=inner)

  @torch.no_grad
- def step(self, var):
- params = TensorList(var.params)
-
- closure = var.closure
- if closure is None: raise RuntimeError('NewtonCG requires closure')
-
- settings = self.settings[params[0]]
- rank = settings['rank']
- reg = settings['reg']
- hvp_method = settings['hvp_method']
- h = settings['h']
-
- seed = settings['seed']
- generator = None
- if seed is not None:
- if 'generator' not in self.global_state:
- self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
- generator = self.global_state['generator']
+ def update_states(self, objective, states, settings):
+ params = TensorList(objective.params)
+ fs = settings[0]

  # ---------------------- Hessian vector product function --------------------- #
- if hvp_method == 'autograd':
- grad = var.get_grad(create_graph=True)
+ hvp_method = fs['hvp_method']
+ h = fs['h']
+ _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+
+ # ---------------------------------- sketch ---------------------------------- #
+ ndim = sum(t.numel() for t in objective.params)
+ device = params[0].device
+ dtype = params[0].dtype
+
+ generator = self.get_generator(params[0].device, seed=fs['seed'])
+ try:
+ # compute the approximation
+ L, Q = nystrom_approximation(
+ A_mv=H_mv,
+ A_mm=H_mm,
+ ndim=ndim,
+ rank=min(fs["rank"], ndim),
+ eigv_tol=fs["eigv_tol"],
+ orthogonalize_method=fs["orthogonalize_method"],
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ )

- def H_mm(x):
- with torch.enable_grad():
- Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
- return torch.cat([t.ravel() for t in Hvp])
+ # regularize
+ L, Q = regularize_eigh(
+ L=L,
+ Q=Q,
+ truncate=fs["truncate"],
+ tol=fs["eigv_tol"],
+ damping=fs["damping"],
+ rdamping=fs["rdamping"],
+ )

- else:
+ # store
+ if L is not None:
+ self.global_state["L"] = L
+ self.global_state["Q"] = Q

- with torch.enable_grad():
- grad = var.get_grad()
+ except torch.linalg.LinAlgError as e:
+ warnings.warn(f"Nystrom approximation failed with: {e}")

- if hvp_method == 'forward':
- def H_mm(x):
- Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
- return torch.cat([t.ravel() for t in Hvp])
+ def apply_states(self, objective, states, settings):
+ if "L" not in self.global_state:
+ return objective

- elif hvp_method == 'central':
- def H_mm(x):
- Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
- return torch.cat([t.ravel() for t in Hvp])
+ fs = settings[0]
+ updates = objective.get_updates()
+ b=torch.cat([t.ravel() for t in updates])

- else:
- raise ValueError(hvp_method)
+ # ----------------------------------- solve ---------------------------------- #
+ L = self.global_state["L"]
+ Q = self.global_state["Q"]

+ if fs["reg"] is None:
+ x = Q @ ((Q.mH @ b) / L)
+ else:
+ x = nystrom_sketch_and_solve(L=L, Q=Q, b=b, reg=fs["reg"])

- # -------------------------------- inner step -------------------------------- #
- b = var.get_update()
- if 'inner' in self.children:
- b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+ # -------------------------------- set update -------------------------------- #
+ objective.updates = vec_to_tensors(x, reference=objective.params)
+ return objective

- # ------------------------------ sketch&n&solve ------------------------------ #
- x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
- var.update = vec_to_tensors(x, reference=params)
- return var
+ def get_H(self, objective=...):
+ if "L" not in self.global_state:
+ return ScaledIdentity()

+ L = self.global_state["L"]
+ Q = self.global_state["Q"]
+ return Eigendecomposition(L, Q)


- class NystromPCG(Module):
+ class NystromPCG(Transform):
  """Newton's method with a Nyström-preconditioned conjugate gradient solver.
- This tends to outperform NewtonCG but requires tuning sketch size.
- An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.

- .. note::
- This module requires the a closure passed to the optimizer step,
+ Notes:
+ - This module requires the a closure passed to the optimizer step,
  as it needs to re-evaluate the loss and gradients for calculating HVPs.
  The closure must accept a ``backward`` argument (refer to documentation).

- .. note::
- In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+ - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

  Args:
- sketch_size (int):
+ rank (int):
  size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
  running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
  conjugate gradient.
@@ -159,31 +193,31 @@ class NystromPCG(Module):
  tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
  reg (float, optional): regularization parameter. Defaults to 1e-8.
  hvp_method (str, optional):
- Determines how Hessian-vector products are evaluated.
-
- - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
- This requires creating a graph for the gradient.
- - ``"forward"``: Use a forward finite difference formula to
- approximate the HVP. This requires one extra gradient evaluation.
- - ``"central"``: Use a central finite difference formula for a
- more accurate HVP approximation. This requires two extra
- gradient evaluations.
- Defaults to "autograd".
- h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+ Determines how Hessian-vector products are computed.
+
+ - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+ - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+ - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+ - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+ Defaults to ``"autograd"``.
+ h (float, optional):
+ The step size for finite difference if ``hvp_method`` is
+ ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
  inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
  seed (int | None, optional): seed for random generator. Defaults to None.

  Examples:

- NystromPCG with backtracking line search
-
- .. code-block:: python
+ NystromPCG with backtracking line search

- opt = tz.Modular(
- model.parameters(),
- tz.m.NystromPCG(10),
- tz.m.Backtracking()
- )
+ ```python
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.NystromPCG(10),
+ tz.m.Backtracking()
+ )
+ ```

  Reference:
  Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
@@ -191,81 +225,78 @@ class NystromPCG(Module):
  """
  def __init__(
  self,
- sketch_size: int,
+ rank: int,
  maxiter=None,
  tol=1e-8,
  reg: float = 1e-6,
- hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+ update_freq: int = 1, # here update_freq is within update_states
+ eigv_tol: float = 0,
+ orthogonalize_method: OrthogonalizeMethod = 'qr',
+ hvp_method: HVPMethod = "batched_autograd",
  h=1e-3,
  inner: Chainable | None = None,
  seed: int | None = None,
  ):
- defaults = dict(sketch_size=sketch_size, reg=reg, maxiter=maxiter, tol=tol, hvp_method=hvp_method, h=h, seed=seed)
- super().__init__(defaults,)
-
- if inner is not None:
- self.set_child('inner', inner)
+ defaults = locals().copy()
+ del defaults['self'], defaults['inner']
+ super().__init__(defaults, inner=inner)

  @torch.no_grad
- def step(self, var):
- params = TensorList(var.params)
-
- closure = var.closure
- if closure is None: raise RuntimeError('NewtonCG requires closure')
-
- settings = self.settings[params[0]]
- sketch_size = settings['sketch_size']
- maxiter = settings['maxiter']
- tol = settings['tol']
- reg = settings['reg']
- hvp_method = settings['hvp_method']
- h = settings['h']
-
-
- seed = settings['seed']
- generator = None
- if seed is not None:
- if 'generator' not in self.global_state:
- self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
- generator = self.global_state['generator']
-
+ def update_states(self, objective, states, settings):
+ fs = settings[0]

  # ---------------------- Hessian vector product function --------------------- #
- if hvp_method == 'autograd':
- grad = var.get_grad(create_graph=True)
-
- def H_mm(x):
- with torch.enable_grad():
- Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
- return torch.cat([t.ravel() for t in Hvp])
-
- else:
-
- with torch.enable_grad():
- grad = var.get_grad()
-
- if hvp_method == 'forward':
- def H_mm(x):
- Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
- return torch.cat([t.ravel() for t in Hvp])
-
- elif hvp_method == 'central':
- def H_mm(x):
- Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
- return torch.cat([t.ravel() for t in Hvp])
-
- else:
- raise ValueError(hvp_method)
-
-
- # -------------------------------- inner step -------------------------------- #
- b = var.get_update()
- if 'inner' in self.children:
- b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-
- # ------------------------------ sketch&n&solve ------------------------------ #
- x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
- var.update = vec_to_tensors(x, reference=params)
- return var
-
+ # this should run on every update_states
+ _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=fs['hvp_method'], h=fs['h'], at_x0=True)
+ objective.temp = H_mv
+
+ # --------------------------- update preconditioner -------------------------- #
+ step = self.increment_counter("step", 0)
+ if step % fs["update_freq"] == 0:
+
+ ndim = sum(t.numel() for t in objective.params)
+ device = objective.params[0].device
+ dtype = objective.params[0].dtype
+ generator = self.get_generator(device, seed=fs['seed'])
+
+ try:
+ L, Q = nystrom_approximation(
+ A_mv=None,
+ A_mm=H_mm,
+ ndim=ndim,
+ rank=min(fs["rank"], ndim),
+ eigv_tol=fs["eigv_tol"],
+ orthogonalize_method=fs["orthogonalize_method"],
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ )
+
+ self.global_state["L"] = L
+ self.global_state["Q"] = Q
+
+ except torch.linalg.LinAlgError as e:
+ warnings.warn(f"Nystrom approximation failed with: {e}")

+ @torch.no_grad
+ def apply_states(self, objective, states, settings):
+ b = objective.get_updates()
+ H_mv = objective.poptemp()
+ fs = self.settings[objective.params[0]]
+
+ # ----------------------------------- solve ---------------------------------- #
+ if "L" not in self.global_state:
+ # fallback on cg
+ sol = cg(A_mv=H_mv, b=TensorList(b), tol=fs["tol"], reg=fs["reg"], maxiter=fs["maxiter"])
+ objective.updates = sol.x
+ return objective
+
+ L = self.global_state["L"]
+ Q = self.global_state["Q"]
+
+ x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
+ reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])
+
+ # -------------------------------- set update -------------------------------- #
+ objective.updates = vec_to_tensors(x, reference=objective.params)
+ return objective
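
To make the two solve branches visible in the diff above concrete, here is a minimal standalone sketch (not taken from the package; `apply_nystrom_preconditioner` is a hypothetical helper) of how a rank-k eigendecomposition `H ≈ Q diag(L) Qᵀ` can be applied to a flattened update vector `b`: the `reg is None` path divides by the eigenvalues directly, while the regularized path solves `(Q diag(L) Qᵀ + reg·I)x = b`, which is exact when `Q` has orthonormal columns. The package's `nystrom_sketch_and_solve` may use a different formula internally.

```python
import torch

def apply_nystrom_preconditioner(L: torch.Tensor, Q: torch.Tensor,
                                 b: torch.Tensor, reg: float | None) -> torch.Tensor:
    """Hypothetical illustration: apply the inverse of Q diag(L) Q^T (+ reg*I) to b."""
    if reg is None:
        # pseudo-inverse solve: x = Q diag(1/L) Q^T b (the reciprocal-eigenvalue branch)
        return Q @ ((Q.mH @ b) / L)
    # regularized solve; exact for Q with orthonormal columns: the component of b
    # in span(Q) is scaled by 1/(L + reg), the orthogonal remainder by 1/reg
    Qtb = Q.mH @ b
    return Q @ (Qtb / (L + reg)) + (b - Q @ Qtb) / reg

# toy usage on a random SPD matrix, keeping the top-k eigenpairs as a stand-in for the sketch
n, k = 50, 10
A = torch.randn(n, n)
H = A @ A.T + torch.eye(n)
L_full, Q_full = torch.linalg.eigh(H)   # ascending eigenvalues
L, Q = L_full[-k:], Q_full[:, -k:]
b = torch.randn(n)
x = apply_nystrom_preconditioner(L, Q, b, reg=1e-2)
```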