torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/nystrom.py
@@ -1,150 +1,141 @@
-from collections.abc import Callable
-from typing import Literal, overload
-import warnings
-import torch
+from typing import Literal
 
-from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+import torch
 
-from ...core import Chainable, apply_transform, Module
-from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import TensorList, vec_to_tensors
+from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg
+from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity
 
-class NystromSketchAndSolve(Module):
+class NystromSketchAndSolve(Transform):
     """Newton's method with a Nyström sketch-and-solve solver.
 
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. note::
-        In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+    Notes:
+        - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
-    .. note::
-        If this is unstable, increase the :code:`reg` parameter and tune the rank.
+        - In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
-    .. note:
-        :code:`tz.m.NystromPCG` usually outperforms this.
+        - If this is unstable, increase the ``reg`` parameter and tune the rank.
 
     Args:
         rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
         reg (float, optional): regularization parameter. Defaults to 1e-3.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are evaluated.
-
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+            Determines how Hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         seed (int | None, optional): seed for random generator. Defaults to None.
 
+
    Examples:
-        NystromSketchAndSolve with backtracking line search
+        NystromSketchAndSolve with backtracking line search
 
-        .. code-block:: python
+        ```py
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NystromSketchAndSolve(100),
+            tz.m.Backtracking()
+        )
+        ```
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.NystromSketchAndSolve(10),
-                tz.m.Backtracking()
-            )
+        Trust region NystromSketchAndSolve
+
+        ```py
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
+        )
+        ```
+
+    References:
+        - [Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.](https://arxiv.org/pdf/2211.08597)
+        - [Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752](https://arxiv.org/abs/2110.02820)
 
-    Reference:
-        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
    """
    def __init__(
        self,
        rank: int,
        reg: float = 1e-3,
-        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        hvp_method: HVPMethod = "batched_autograd",
        h: float = 1e-3,
+        update_freq: int = 1,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
-        defaults = dict(rank=rank, reg=reg, hvp_method=hvp_method, h=h, seed=seed)
-        super().__init__(defaults,)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        rank = settings['rank']
-        reg = settings['reg']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
+    def update_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+        fs = settings[0]
 
        # ---------------------- Hessian vector product function --------------------- #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-
-            def H_mm(x):
-                with torch.enable_grad():
-                    Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
-                return torch.cat([t.ravel() for t in Hvp])
+        hvp_method = fs['hvp_method']
+        h = fs['h']
+        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
 
-        else:
+        # ---------------------------------- sketch ---------------------------------- #
+        ndim = sum(t.numel() for t in objective.params)
+        device = params[0].device
+        dtype = params[0].dtype
 
-            with torch.enable_grad():
-                grad = var.get_grad()
+        generator = self.get_generator(params[0].device, seed=fs['seed'])
+        try:
+            L, Q = nystrom_approximation(A_mv=H_mv, A_mm=H_mm, ndim=ndim, rank=fs['rank'],
+                                         dtype=dtype, device=device, generator=generator)
 
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
-                    return torch.cat([t.ravel() for t in Hvp])
+            self.global_state["L"] = L
+            self.global_state["Q"] = Q
+        except torch.linalg.LinAlgError:
+            pass
 
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
-                    return torch.cat([t.ravel() for t in Hvp])
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        b = objective.get_updates()
 
-            else:
-                raise ValueError(hvp_method)
+        # ----------------------------------- solve ---------------------------------- #
+        if "L" not in self.global_state:
+            return objective
 
+        L = self.global_state["L"]
+        Q = self.global_state["Q"]
+        x = nystrom_sketch_and_solve(L=L, Q=Q, b=torch.cat([t.ravel() for t in b]), reg=fs["reg"])
 
-        # -------------------------------- inner step -------------------------------- #
-        b = var.get_update()
-        if 'inner' in self.children:
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+        # -------------------------------- set update -------------------------------- #
+        objective.updates = vec_to_tensors(x, reference=objective.params)
+        return objective
 
-        # ------------------------------ sketch&n&solve ------------------------------ #
-        x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
-        var.update = vec_to_tensors(x, reference=params)
-        return var
+    def get_H(self, objective=...):
+        if "L" not in self.global_state:
+            return ScaledIdentity()
 
+        L = self.global_state["L"]
+        Q = self.global_state["Q"]
+        return Eigendecomposition(L, Q)
 
 
-class NystromPCG(Module):
+class NystromPCG(Transform):
    """Newton's method with a Nyström-preconditioned conjugate gradient solver.
    This tends to outperform NewtonCG but requires tuning sketch size.
    An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.
 
-    .. note::
-        This module requires the a closure passed to the optimizer step,
+    Notes:
+        - This module requires the a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating HVPs.
        The closure must accept a ``backward`` argument (refer to documentation).
 
-    .. note::
-        In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+        - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
    Args:
        sketch_size (int):
@@ -159,31 +150,31 @@ class NystromPCG(Module):
        tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
        reg (float, optional): regularization parameter. Defaults to 1e-8.
        hvp_method (str, optional):
-            Determines how Hessian-vector products are evaluated.
-
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+            Determines how Hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        seed (int | None, optional): seed for random generator. Defaults to None.
 
    Examples:
 
-        NystromPCG with backtracking line search
+        NystromPCG with backtracking line search
 
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.NystromPCG(10),
-                tz.m.Backtracking()
-            )
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NystromPCG(10),
+            tz.m.Backtracking()
+        )
+        ```
 
    Reference:
        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
@@ -191,81 +182,70 @@ class NystromPCG(Module):
    """
    def __init__(
        self,
-        sketch_size: int,
+        rank: int,
        maxiter=None,
        tol=1e-8,
        reg: float = 1e-6,
-        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        update_freq: int = 1, # here update_freq is within update_states
+        hvp_method: HVPMethod = "batched_autograd",
        h=1e-3,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
-        defaults = dict(sketch_size=sketch_size, reg=reg, maxiter=maxiter, tol=tol, hvp_method=hvp_method, h=h, seed=seed)
-        super().__init__(defaults,)
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults, inner=inner)
 
    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        sketch_size = settings['sketch_size']
-        maxiter = settings['maxiter']
-        tol = settings['tol']
-        reg = settings['reg']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
-
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
        # ---------------------- Hessian vector product function --------------------- #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-
-            def H_mm(x):
-                with torch.enable_grad():
-                    Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
-                return torch.cat([t.ravel() for t in Hvp])
-
-        else:
+        # this should run on every update_states
+        hvp_method = fs['hvp_method']
+        h = fs['h']
+        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+        objective.temp = H_mv
 
-            with torch.enable_grad():
-                grad = var.get_grad()
+        # --------------------------- update preconditioner -------------------------- #
+        step = self.increment_counter("step", 0)
+        update_freq = self.defaults["update_freq"]
 
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
-                    return torch.cat([t.ravel() for t in Hvp])
+        if step % update_freq == 0:
 
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
-                    return torch.cat([t.ravel() for t in Hvp])
+            rank = fs['rank']
+            ndim = sum(t.numel() for t in objective.params)
+            device = objective.params[0].device
+            dtype = objective.params[0].dtype
+            generator = self.get_generator(device, seed=fs['seed'])
 
-            else:
-                raise ValueError(hvp_method)
-
-
-        # -------------------------------- inner step -------------------------------- #
-        b = var.get_update()
-        if 'inner' in self.children:
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-
-        # ------------------------------ sketch&n&solve ------------------------------ #
-        x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
-        var.update = vec_to_tensors(x, reference=params)
-        return var
+            try:
+                L, Q = nystrom_approximation(A_mv=None, A_mm=H_mm, ndim=ndim, rank=rank,
+                                             dtype=dtype, device=device, generator=generator)
 
+                self.global_state["L"] = L
+                self.global_state["Q"] = Q
+            except torch.linalg.LinAlgError:
+                pass
 
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        b = objective.get_updates()
+        H_mv = objective.poptemp()
+        fs = self.settings[objective.params[0]]
+
+        # ----------------------------------- solve ---------------------------------- #
+        if "L" not in self.global_state:
+            # fallback on cg
+            sol = cg(A_mv=H_mv, b=TensorList(b), tol=fs["tol"], reg=fs["reg"], maxiter=fs["maxiter"])
+            objective.updates = sol.x
+            return objective
+
+        L = self.global_state["L"]
+        Q = self.global_state["Q"]
+        x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
+                        reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])
+
+        # -------------------------------- set update -------------------------------- #
+        objective.updates = vec_to_tensors(x, reference=objective.params)
+        return objective
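
For orientation, a minimal usage sketch of the reworked 0.4.0 modules shown in this diff follows. It sticks to the signatures and docstring examples above (``rank`` as the first argument, the new ``update_freq`` and ``hvp_method`` values, a closure taking a ``backward`` flag); the toy model, data, and the exact closure convention are illustrative assumptions, not part of the diff.

```python
import torch
import torchzero as tz

# toy least-squares problem; any model/loss works as long as the closure re-evaluates it
model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    # rank-10 Nyström preconditioner, refreshed every 5 steps;
    # HVPs via central finite differences instead of autograd
    tz.m.NystromPCG(10, hvp_method="fd_central", update_freq=5),
    tz.m.Backtracking(),
)

def closure(backward=True):
    # the module re-evaluates loss/gradients through this closure;
    # the `backward` flag follows the docstring note above
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    loss = opt.step(closure)
```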
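
The new solver path (``nystrom_approximation`` producing ``L, Q`` and ``nystrom_sketch_and_solve(L=L, Q=Q, b=..., reg=...)``) corresponds to the standard Nyström sketch-and-solve step from the Frangella et al. papers cited in the docstrings. Below is a self-contained reference sketch of that step, not the library's implementation, assuming ``Q`` has orthonormal columns and ``L`` holds the eigenvalues of the rank-r approximation (as the ``Eigendecomposition(L, Q)`` usage above suggests).

```python
import torch

def nystrom_sketch_and_solve_ref(L: torch.Tensor, Q: torch.Tensor,
                                 b: torch.Tensor, reg: float) -> torch.Tensor:
    """Solve (H_hat + reg*I) x = b where H_hat = Q @ diag(L) @ Q.T.

    With orthonormal Q, the Woodbury identity gives
        (H_hat + reg*I)^-1 = Q diag(1/(L + reg)) Q^T + (I - Q Q^T) / reg.
    """
    Qtb = Q.T @ b                     # project b onto the sketch subspace
    x_range = Q @ (Qtb / (L + reg))   # solve within the range of Q
    x_null = (b - Q @ Qtb) / reg      # plain damping on the orthogonal complement
    return x_range + x_null
```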