torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
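The file list shows `torchzero/utils/linalg` being folded into a new top-level `torchzero/linalg` package (and the old `torchzero/utils/linalg/*` modules removed). A minimal import sketch of what that move implies, inferred from the relative `from ...linalg import ...` lines in the updated `nystrom.py` diff below; the exact set of re-exported names is an assumption:

```python
# 0.3.14 layout (modules removed in this release):
#   from torchzero.utils.linalg.solve import nystrom_pcg, nystrom_sketch_and_solve

# 0.4.0 layout, mirroring the relative imports used inside torchzero.modules below:
from torchzero.linalg import nystrom_pcg, nystrom_sketch_and_solve
```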
torchzero/modules/second_order/nystrom.py
@@ -1,150 +1,141 @@
- from collections.abc import Callable
- from typing import Literal, overload
- import warnings
- import torch
+ from typing import Literal

- from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
- from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+ import torch

- from ...core import Chainable, apply_transform, Module
- from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg
+ from ...core import Chainable, Transform, HVPMethod
+ from ...utils import TensorList, vec_to_tensors
+ from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg
+ from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity

- class NystromSketchAndSolve(Module):
+ class NystromSketchAndSolve(Transform):
      """Newton's method with a Nyström sketch-and-solve solver.

-     .. note::
-         This module requires the a closure passed to the optimizer step,
-         as it needs to re-evaluate the loss and gradients for calculating HVPs.
-         The closure must accept a ``backward`` argument (refer to documentation).
-
-     .. note::
-         In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+     Notes:
+         - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

-     .. note::
-         If this is unstable, increase the :code:`reg` parameter and tune the rank.
+         - In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

-     .. note:
-         :code:`tz.m.NystromPCG` usually outperforms this.
+         - If this is unstable, increase the ``reg`` parameter and tune the rank.

      Args:
          rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
          reg (float, optional): regularization parameter. Defaults to 1e-3.
          hvp_method (str, optional):
-             Determines how Hessian-vector products are evaluated.
-
-             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-               This requires creating a graph for the gradient.
-             - ``"forward"``: Use a forward finite difference formula to
-               approximate the HVP. This requires one extra gradient evaluation.
-             - ``"central"``: Use a central finite difference formula for a
-               more accurate HVP approximation. This requires two extra
-               gradient evaluations.
-             Defaults to "autograd".
-         h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+             Determines how Hessian-vector products are computed.
+
+             - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+             - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+             - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+             - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+             Defaults to ``"autograd"``.
+         h (float, optional):
+             The step size for finite difference if ``hvp_method`` is
+             ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
          inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
          seed (int | None, optional): seed for random generator. Defaults to None.

+
      Examples:
-     NystromSketchAndSolve with backtracking line search
+         NystromSketchAndSolve with backtracking line search

-     .. code-block:: python
+         ```py
+         opt = tz.Modular(
+             model.parameters(),
+             tz.m.NystromSketchAndSolve(100),
+             tz.m.Backtracking()
+         )
+         ```

-         opt = tz.Modular(
-             model.parameters(),
-             tz.m.NystromSketchAndSolve(10),
-             tz.m.Backtracking()
-         )
+         Trust region NystromSketchAndSolve
+
+         ```py
+         opt = tz.Modular(
+             model.parameters(),
+             tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
+         )
+         ```
+
+     References:
+         - [Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.](https://arxiv.org/pdf/2211.08597)
+         - [Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752](https://arxiv.org/abs/2110.02820)

-     Reference:
-         Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
      """
      def __init__(
          self,
          rank: int,
          reg: float = 1e-3,
-         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+         hvp_method: HVPMethod = "batched_autograd",
          h: float = 1e-3,
+         update_freq: int = 1,
          inner: Chainable | None = None,
          seed: int | None = None,
      ):
-         defaults = dict(rank=rank, reg=reg, hvp_method=hvp_method, h=h, seed=seed)
-         super().__init__(defaults,)
-
-         if inner is not None:
-             self.set_child('inner', inner)
+         defaults = locals().copy()
+         del defaults['self'], defaults['inner'], defaults["update_freq"]
+         super().__init__(defaults, update_freq=update_freq, inner=inner)

      @torch.no_grad
-     def step(self, var):
-         params = TensorList(var.params)
-
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         rank = settings['rank']
-         reg = settings['reg']
-         hvp_method = settings['hvp_method']
-         h = settings['h']
-
-         seed = settings['seed']
-         generator = None
-         if seed is not None:
-             if 'generator' not in self.global_state:
-                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-             generator = self.global_state['generator']
+     def update_states(self, objective, states, settings):
+         params = TensorList(objective.params)
+         fs = settings[0]

          # ---------------------- Hessian vector product function --------------------- #
-         if hvp_method == 'autograd':
-             grad = var.get_grad(create_graph=True)
-
-             def H_mm(x):
-                 with torch.enable_grad():
-                     Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
-                 return torch.cat([t.ravel() for t in Hvp])
+         hvp_method = fs['hvp_method']
+         h = fs['h']
+         _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)

-         else:
+         # ---------------------------------- sketch ---------------------------------- #
+         ndim = sum(t.numel() for t in objective.params)
+         device = params[0].device
+         dtype = params[0].dtype

-             with torch.enable_grad():
-                 grad = var.get_grad()
+         generator = self.get_generator(params[0].device, seed=fs['seed'])
+         try:
+             L, Q = nystrom_approximation(A_mv=H_mv, A_mm=H_mm, ndim=ndim, rank=fs['rank'],
+                                          dtype=dtype, device=device, generator=generator)

-             if hvp_method == 'forward':
-                 def H_mm(x):
-                     Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
-                     return torch.cat([t.ravel() for t in Hvp])
+             self.global_state["L"] = L
+             self.global_state["Q"] = Q
+         except torch.linalg.LinAlgError:
+             pass

-             elif hvp_method == 'central':
-                 def H_mm(x):
-                     Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
-                     return torch.cat([t.ravel() for t in Hvp])
+     def apply_states(self, objective, states, settings):
+         fs = settings[0]
+         b = objective.get_updates()

-             else:
-                 raise ValueError(hvp_method)
+         # ----------------------------------- solve ---------------------------------- #
+         if "L" not in self.global_state:
+             return objective

+         L = self.global_state["L"]
+         Q = self.global_state["Q"]
+         x = nystrom_sketch_and_solve(L=L, Q=Q, b=torch.cat([t.ravel() for t in b]), reg=fs["reg"])

-         # -------------------------------- inner step -------------------------------- #
-         b = var.get_update()
-         if 'inner' in self.children:
-             b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+         # -------------------------------- set update -------------------------------- #
+         objective.updates = vec_to_tensors(x, reference=objective.params)
+         return objective

-         # ------------------------------ sketch&n&solve ------------------------------ #
-         x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
-         var.update = vec_to_tensors(x, reference=params)
-         return var
+     def get_H(self, objective=...):
+         if "L" not in self.global_state:
+             return ScaledIdentity()

+         L = self.global_state["L"]
+         Q = self.global_state["Q"]
+         return Eigendecomposition(L, Q)


- class NystromPCG(Module):
+ class NystromPCG(Transform):
      """Newton's method with a Nyström-preconditioned conjugate gradient solver.
      This tends to outperform NewtonCG but requires tuning sketch size.
      An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.

-     .. note::
-         This module requires the a closure passed to the optimizer step,
+     Notes:
+         - This module requires the a closure passed to the optimizer step,
          as it needs to re-evaluate the loss and gradients for calculating HVPs.
          The closure must accept a ``backward`` argument (refer to documentation).

-     .. note::
-         In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+         - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

      Args:
          sketch_size (int):
@@ -159,31 +150,31 @@ class NystromPCG(Module):
          tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
          reg (float, optional): regularization parameter. Defaults to 1e-8.
          hvp_method (str, optional):
-             Determines how Hessian-vector products are evaluated.
-
-             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-               This requires creating a graph for the gradient.
-             - ``"forward"``: Use a forward finite difference formula to
-               approximate the HVP. This requires one extra gradient evaluation.
-             - ``"central"``: Use a central finite difference formula for a
-               more accurate HVP approximation. This requires two extra
-               gradient evaluations.
-             Defaults to "autograd".
-         h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+             Determines how Hessian-vector products are computed.
+
+             - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+             - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+             - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+             - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+             Defaults to ``"autograd"``.
+         h (float, optional):
+             The step size for finite difference if ``hvp_method`` is
+             ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
          inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
          seed (int | None, optional): seed for random generator. Defaults to None.

      Examples:

-     NystromPCG with backtracking line search
+         NystromPCG with backtracking line search

-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.NystromPCG(10),
-                 tz.m.Backtracking()
-             )
+         ```python
+         opt = tz.Modular(
+             model.parameters(),
+             tz.m.NystromPCG(10),
+             tz.m.Backtracking()
+         )
+         ```

      Reference:
          Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
@@ -191,81 +182,70 @@ class NystromPCG(Module):
      """
      def __init__(
          self,
-         sketch_size: int,
+         rank: int,
          maxiter=None,
-         tol=1e-3,
+         tol=1e-8,
          reg: float = 1e-6,
-         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+         update_freq: int = 1, # here update_freq is within update_states
+         hvp_method: HVPMethod = "batched_autograd",
          h=1e-3,
          inner: Chainable | None = None,
          seed: int | None = None,
      ):
-         defaults = dict(sketch_size=sketch_size, reg=reg, maxiter=maxiter, tol=tol, hvp_method=hvp_method, h=h, seed=seed)
-         super().__init__(defaults,)
-
-         if inner is not None:
-             self.set_child('inner', inner)
+         defaults = locals().copy()
+         del defaults['self'], defaults['inner']
+         super().__init__(defaults, inner=inner)

      @torch.no_grad
-     def step(self, var):
-         params = TensorList(var.params)
-
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         sketch_size = settings['sketch_size']
-         maxiter = settings['maxiter']
-         tol = settings['tol']
-         reg = settings['reg']
-         hvp_method = settings['hvp_method']
-         h = settings['h']
-
-
-         seed = settings['seed']
-         generator = None
-         if seed is not None:
-             if 'generator' not in self.global_state:
-                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-             generator = self.global_state['generator']
-
+     def update_states(self, objective, states, settings):
+         fs = settings[0]

          # ---------------------- Hessian vector product function --------------------- #
-         if hvp_method == 'autograd':
-             grad = var.get_grad(create_graph=True)
-
-             def H_mm(x):
-                 with torch.enable_grad():
-                     Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
-                 return torch.cat([t.ravel() for t in Hvp])
-
-         else:
+         # this should run on every update_states
+         hvp_method = fs['hvp_method']
+         h = fs['h']
+         _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+         objective.temp = H_mv

-             with torch.enable_grad():
-                 grad = var.get_grad()
+         # --------------------------- update preconditioner -------------------------- #
+         step = self.increment_counter("step", 0)
+         update_freq = self.defaults["update_freq"]

-             if hvp_method == 'forward':
-                 def H_mm(x):
-                     Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
-                     return torch.cat([t.ravel() for t in Hvp])
+         if step % update_freq == 0:

-             elif hvp_method == 'central':
-                 def H_mm(x):
-                     Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
-                     return torch.cat([t.ravel() for t in Hvp])
+             rank = fs['rank']
+             ndim = sum(t.numel() for t in objective.params)
+             device = objective.params[0].device
+             dtype = objective.params[0].dtype
+             generator = self.get_generator(device, seed=fs['seed'])

-             else:
-                 raise ValueError(hvp_method)
-
-
-         # -------------------------------- inner step -------------------------------- #
-         b = var.get_update()
-         if 'inner' in self.children:
-             b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-
-         # ------------------------------ sketch&n&solve ------------------------------ #
-         x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
-         var.update = vec_to_tensors(x, reference=params)
-         return var
+             try:
+                 L, Q = nystrom_approximation(A_mv=None, A_mm=H_mm, ndim=ndim, rank=rank,
+                                              dtype=dtype, device=device, generator=generator)

+                 self.global_state["L"] = L
+                 self.global_state["Q"] = Q
+             except torch.linalg.LinAlgError:
+                 pass

+     @torch.no_grad
+     def apply_states(self, objective, states, settings):
+         b = objective.get_updates()
+         H_mv = objective.poptemp()
+         fs = self.settings[objective.params[0]]
+
+         # ----------------------------------- solve ---------------------------------- #
+         if "L" not in self.global_state:
+             # fallback on cg
+             sol = cg(A_mv=H_mv, b=TensorList(b), tol=fs["tol"], reg=fs["reg"], maxiter=fs["maxiter"])
+             objective.updates = sol.x
+             return objective
+
+         L = self.global_state["L"]
+         Q = self.global_state["Q"]
+         x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
+                         reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])
+
+         # -------------------------------- set update -------------------------------- #
+         objective.updates = vec_to_tensors(x, reference=objective.params)
+         return objective
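For orientation, a minimal usage sketch of the reworked NystromPCG shown above, assuming the `tz.Modular` and `tz.m` entry points used in its docstring examples; parameter names follow the new 0.4.0 signature (first argument renamed from `sketch_size` to `rank`, finite-difference HVP modes renamed to `"fd_forward"`/`"fd_central"`, new `update_freq`):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# 0.3.14: tz.m.NystromPCG(sketch_size=10, hvp_method="forward")
# 0.4.0: the first argument is `rank`, finite-difference HVPs are selected with
# "fd_forward"/"fd_central", and `update_freq` controls how often the Nyström
# preconditioner is rebuilt (illustrative values).
opt = tz.Modular(
    model.parameters(),
    tz.m.NystromPCG(10, hvp_method="fd_forward", update_freq=4),
    tz.m.Backtracking(),
)
# Per the docstring above, the optimizer step must be given a closure that
# accepts a ``backward`` argument so gradients and HVPs can be re-evaluated.
```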
torchzero/modules/second_order/rsn.py (new file)
@@ -0,0 +1,237 @@
+ import math
+ from collections import deque
+ from collections.abc import Callable
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Transform, HVPMethod
+ from ...utils import vec_to_tensors
+ from ...linalg.linear_operator import Sketched
+
+ from .newton import _newton_step
+
+ def _qr_orthonormalize(A:torch.Tensor):
+     m,n = A.shape
+     if m < n:
+         q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
+         return q.T
+
+     q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
+     return q
+
+ def _orthonormal_sketch(m, n, dtype, device, generator):
+     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
+
+ def _gaussian_sketch(m, n, dtype, device, generator):
+     return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
+
+ def _rademacher_sketch(m, n, dtype, device, generator):
+     rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
+     return rademacher.mul_(1 / math.sqrt(m))
+
+ class SubspaceNewton(Transform):
+     """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
+
+     Args:
+         sketch_size (int):
+             size of the random sketch. This many hessian-vector products will need to be evaluated each step.
+         sketch_type (str, optional):
+             - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
+             - "rademacher" - approximately orthonormal scaled random rademacher basis.
+             - "gaussian" - random gaussian (not orthonormal) basis.
+             - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt.
+             - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction (default).
+         damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
+         hvp_method (str, optional):
+             How to compute hessian-matrix product:
+             - "batched_autograd" - uses batched autograd
+             - "autograd" - uses unbatched autograd
+             - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
+             - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
+
+             . Defaults to "batched_autograd".
+         h (float, optional): finite difference step size. Defaults to 1e-2.
+         use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to False.
+         update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+         H_tfm (Callable | None, optional):
+             optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
+
+             must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+             which must be True if transform inverted the hessian and False otherwise.
+
+             Or it returns a single tensor which is used as the update.
+
+             Defaults to None.
+         eigval_fn (Callable | None, optional):
+             optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
+             If this is specified, eigendecomposition will be used to invert the hessian.
+         seed (int | None, optional): seed for random generator. Defaults to None.
+         inner (Chainable | None, optional): preconditions output of this module. Defaults to None.
+
+     ### Examples
+
+     RSN with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.RSN(),
+         tz.m.Backtracking()
+     )
+     ```
+
+     RSN with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LevenbergMarquardt(tz.m.RSN()),
+     )
+     ```
+
+
+     References:
+         1. [Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).](https://arxiv.org/abs/1905.10874)
+         2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
+     """
+
+     def __init__(
+         self,
+         sketch_size: int,
+         sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
+         damping:float=0,
+         hvp_method: HVPMethod = "batched_autograd",
+         h: float = 1e-2,
+         use_lstsq: bool = True,
+         update_freq: int = 1,
+         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+         seed: int | None = None,
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         del defaults['self'], defaults['inner'], defaults["update_freq"]
+         super().__init__(defaults, update_freq=update_freq, inner=inner)
+
+     @torch.no_grad
+     def update_states(self, objective, states, settings):
+         fs = settings[0]
+         params = objective.params
+         generator = self.get_generator(params[0].device, fs["seed"])
+
+         ndim = sum(p.numel() for p in params)
+
+         device=params[0].device
+         dtype=params[0].dtype
+
+         # sample sketch matrix S: (ndim, sketch_size)
+         sketch_size = min(fs["sketch_size"], ndim)
+         sketch_type = fs["sketch_type"]
+         hvp_method = fs["hvp_method"]
+
+         if sketch_type in ('normal', 'gaussian'):
+             S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
+         elif sketch_type == "rademacher":
+             S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
+         elif sketch_type == 'orthonormal':
+             S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
+         elif sketch_type == 'common_directions':
+             # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
+             g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+             g = torch.cat([t.ravel() for t in g_list])
+
+             # initialize directions deque
+             if "directions" not in self.global_state:
+
+                 g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+                 if g_norm < torch.finfo(g.dtype).tiny * 2:
+                     g = torch.randn_like(g)
+                     g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+
+                 self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
+                 S = self.global_state["directions"][0].unsqueeze(1)
+
+             # add new steepest descent direction orthonormal to existing columns
+             else:
+                 S = torch.stack(tuple(self.global_state["directions"]), dim=1)
+                 p = g - S @ (S.T @ g)
+                 p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
+                 if p_norm > torch.finfo(p.dtype).tiny * 2:
+                     p = p / p_norm
+                     self.global_state["directions"].append(p)
+                     S = torch.cat([S, p.unsqueeze(1)], dim=1)
+
+         elif sketch_type == "mixed":
+             g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+             g = torch.cat([t.ravel() for t in g_list])
+
+             # initialize state
+             if "slow_ema" not in self.global_state:
+                 self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
+                 self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
+                 self.global_state["p_prev"] = torch.randn_like(g)
+
+             # previous update direction
+             p_cur = torch.cat([t.ravel() for t in params])
+             prev_dir = p_cur - self.global_state["p_prev"]
+             self.global_state["p_prev"] = p_cur
+
+             # EMAs
+             slow_ema = self.global_state["slow_ema"]
+             fast_ema = self.global_state["fast_ema"]
+             slow_ema.lerp_(g, 0.001)
+             fast_ema.lerp_(g, 0.1)
+
+             # form and orthogonalize sketching matrix
+             S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
+             if sketch_size > 4:
+                 S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
+                 S = torch.cat([S, S_random], dim=1)
+
+             S = _qr_orthonormalize(S)
+
+         else:
+             raise ValueError(f'Unknown sketch_type {sketch_type}')
+
+         # form sketched hessian
+         HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
+                                                  hvp_method=fs["hvp_method"], h=fs["h"])
+         H_sketched = S.T @ HS
+
+         self.global_state["H_sketched"] = H_sketched
+         self.global_state["S"] = S
+
+     def apply_states(self, objective, states, settings):
+         S: torch.Tensor = self.global_state["S"]
+
+         d_proj = _newton_step(
+             objective=objective,
+             H=self.global_state["H_sketched"],
+             damping=self.defaults["damping"],
+             H_tfm=self.defaults["H_tfm"],
+             eigval_fn=self.defaults["eigval_fn"],
+             use_lstsq=self.defaults["use_lstsq"],
+             g_proj = lambda g: S.T @ g
+         )
+
+         d = S @ d_proj
+         objective.updates = vec_to_tensors(d, objective.params)
+         return objective
+
+     def get_H(self, objective=...):
+         eigval_fn = self.defaults["eigval_fn"]
+         H_sketched: torch.Tensor = self.global_state["H_sketched"]
+         S: torch.Tensor = self.global_state["S"]
+
+         if eigval_fn is not None:
+             try:
+                 L, Q = torch.linalg.eigh(H_sketched) # pylint:disable=not-callable
+                 L: torch.Tensor = eigval_fn(L)
+                 H_sketched = Q @ L.diag_embed() @ Q.mH
+
+             except torch.linalg.LinAlgError:
+                 pass
+
+         return Sketched(S, H_sketched)
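The new `rsn.py` module above defines `SubspaceNewton`; its docstring examples refer to it through the `tz.m.RSN` alias. A hedged construction sketch, assuming that alias accepts the `__init__` signature shown above (keyword names are taken from that signature, the sketch size of 20 is illustrative):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# Newton step restricted to a 20-dimensional subspace spanned by past
# steepest-descent directions ("common_directions"), followed by a
# backtracking line search, mirroring the docstring examples.
opt = tz.Modular(
    model.parameters(),
    tz.m.RSN(sketch_size=20, sketch_type="common_directions"),
    tz.m.Backtracking(),
)
```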