torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/second_order/newton.py
@@ -1,165 +1,142 @@
- from typing import Literal
- from collections import abc
-
- import torch
-
- from ...utils.derivatives import hessian_list_to_mat, jacobian_and_hessian
- from ...tensorlist import TensorList
- from ...core import OptimizerModule
-
-
- def _cholesky_solve(hessian: torch.Tensor, grad: torch.Tensor):
-     cholesky, info = torch.linalg.cholesky_ex(hessian) # pylint:disable=not-callable
-     if info == 0:
-         grad.unsqueeze_(1)
-         return torch.cholesky_solve(grad, cholesky), True
-     return None, False
-
- def _lu_solve(hessian: torch.Tensor, grad: torch.Tensor):
-     try:
-         newton_step, info = torch.linalg.solve_ex(hessian, grad) # pylint:disable=not-callable
-         if info == 0: return newton_step, True
-         return None, False
-     except torch.linalg.LinAlgError:
-         return None, False
-
-
- def _cholesky_fallback_lu(hessian: torch.Tensor, grad: torch.Tensor):
-     step, success = _cholesky_solve(hessian, grad)
-     if not success:
-         step, success = _lu_solve(hessian, grad)
-     return step, success
-
- def _least_squares_solve(hessian: torch.Tensor, grad: torch.Tensor):
-     return torch.linalg.lstsq(hessian, grad)[0], True # pylint:disable=not-callable
-
-
- def _fallback_gd(hessian:torch.Tensor, grad:torch.Tensor, lr = 1e-2):
-     return grad.mul_(1e-2), True
-
- def _fallback_safe_diag(hessian:torch.Tensor, grad:torch.Tensor, lr = 1e-2):
-     diag = hessian.diag().reciprocal_().nan_to_num_(1,1,1)
-     if torch.all(diag == 1): # fallback to gd
-         return _fallback_gd(hessian, grad, lr)
-     return grad.mul_(diag * lr), True
-
-
- def regularize_hessian_(hessian: torch.Tensor, value: float | Literal['eig']):
-     """regularize hessian matrix in-place"""
-     if value == 'eig':
-         value = torch.linalg.eigvalsh(hessian).min().clamp_(max=0).neg_() # pylint:disable=not-callable
-     elif value != 0:
-         hessian.add_(torch.eye(hessian.shape[0], device=hessian.device,dtype=hessian.dtype), alpha = value)
-
- LinearSystemSolvers = Literal['cholesky', 'lu', 'cholesky_lu', 'lstsq']
- FallbackLinearSystemSolvers = Literal['lstsq', 'safe_diag', 'gd']
-
- LINEAR_SYSTEM_SOLVERS = {
-     "cholesky": _cholesky_solve,
-     "lu": _lu_solve,
-     "cholesky_lu": _cholesky_fallback_lu,
-     "lstsq": _least_squares_solve,
-     "safe_diag": _fallback_safe_diag,
-     "gd": _fallback_gd
- }
-
- class ExactNewton(OptimizerModule):
-     """Peforms an exact Newton step using batched autograd.
-
-     Note that this doesn't support per-group settings.
-
-     Args:
-         tikhonov (float, optional):
-             tikhonov regularization (constant value added to the diagonal of the hessian).
-             Also known as Levenberg-Marquardt regularization. Can be set to 'eig', so it will be set
-             to the smallest eigenvalue of the hessian if that value is negative. Defaults to 0.
-         solver (Solvers, optional):
-             solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-         fallback (Solvers, optional):
-             what to do if solver fails. Defaults to "safe_diag"
-             (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-         validate (bool, optional):
-             validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-             If not, undo the step and perform a gradient descent step.
-         tol (float, optional):
-             only has effect if `validate` is enabled.
-             If loss increased by `loss * tol`, perform gradient descent step.
-             Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-         gd_lr (float, optional):
-             only has effect if `validate` is enabled.
-             Gradient descent step learning rate. Defaults to 1e-2.
-         batched_hessian (bool, optional):
-             whether to use experimental pytorch vmap-vectorized hessian calculation. As per pytorch docs,
-             should be faster, but this feature being experimental, there may be performance cliffs.
-             Defaults to True.
-         diag (False, optional):
-             only use the diagonal of the hessian. This will still calculate the full hessian!
-             This is mainly useful for benchmarking.
-     """
-     def __init__(
-         self,
-         tikhonov: float | Literal['eig'] = 0.0,
-         solver: LinearSystemSolvers = "cholesky_lu",
-         fallback: FallbackLinearSystemSolvers = "safe_diag",
-         validate=False,
-         tol: float = 1,
-         gd_lr = 1e-2,
-         batched_hessian=True,
-         diag: bool = False,
-     ):
-         super().__init__({})
-         self.tikhonov: float | Literal['eig'] = tikhonov
-         self.batched_hessian = batched_hessian
-
-         self.solver: abc.Callable = LINEAR_SYSTEM_SOLVERS[solver]
-         self.fallback: abc.Callable = LINEAR_SYSTEM_SOLVERS[fallback]
-
-         self.validate = validate
-         self.gd_lr = gd_lr
-         self.tol = tol
-
-         self.diag = diag
-
-     @torch.no_grad
-     def step(self, vars):
-         if vars.closure is None: raise ValueError("Newton requires a closure to compute the gradient.")
-
-         params = self.get_params()
-
-         # exact hessian via autograd
-         with torch.enable_grad():
-             vars.fx0 = vars.closure(False)
-             grads, hessian = jacobian_and_hessian([vars.fx0], params) # type:ignore
-             vars.grad = grads = TensorList(grads).squeeze_(0)
-             gvec = grads.to_vec()
-             hessian = hessian_list_to_mat(hessian)
-
-         # tikhonov regularization
-         regularize_hessian_(hessian, self.tikhonov)
-
-         # calculate newton step
-         if self.diag:
-             newton_step = gvec / hessian.diag()
-         else:
-             newton_step, success = self.solver(hessian, gvec)
-             if not success:
-                 newton_step, success = self.fallback(hessian, gvec)
-                 if not success:
-                     newton_step, success = _fallback_gd(hessian, gvec)
-
-         # apply the `_update` method
-         vars.ascent = grads.from_vec(newton_step.squeeze_().nan_to_num_(0,0,0))
-
-         # validate if newton step decreased loss
-         if self.validate:
-
-             params.sub_(vars.ascent)
-             fx1 = vars.closure(False)
-             params.add_(vars.ascent)
-
-             # if loss increases, set ascent direction to grad times lr
-             if (not fx1.isfinite()) or fx1 - vars.fx0 > vars.fx0 * self.tol: # type:ignore
-                 vars.ascent = grads.div_(grads.total_vector_norm(2) / self.gd_lr)
-
-         # peform an update with the ascent direction, or pass it to the child.
-         return self._update_params_or_step_with_next(vars, params=params)
+ import warnings
+ from functools import partial
+ from typing import Literal
+ from collections.abc import Callable
+ import torch
+
+ from ...core import Chainable, apply, Module
+ from ...utils import vec_to_tensors, TensorList
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     hessian_mat,
+     jacobian_and_hessian_wrt,
+ )
+
+
+ def lu_solve(H: torch.Tensor, g: torch.Tensor):
+     x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+     if info == 0: return x
+     return None
+
+ def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
+     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
+     if info == 0:
+         g.unsqueeze_(1)
+         return torch.cholesky_solve(g, x)
+     return None
+
+ def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
+     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable
+
+ def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None):
+     try:
+         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+         if tfm is not None: L = tfm(L)
+         L.reciprocal_()
+         return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
+     except torch.linalg.LinAlgError:
+         return None
+
+ def tikhonov_(H: torch.Tensor, reg: float):
+     if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
+     return H
+
+ def eig_tikhonov_(H: torch.Tensor, reg: float):
+     v = torch.linalg.eigvalsh(H).min().clamp_(max=0).neg_() + reg # pylint:disable=not-callable
+     return tikhonov_(H, v)
+
+
+ class Newton(Module):
+     """Exact newton via autograd.
+
+     Args:
+         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
+         eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+         hessian_method (str):
+             how to calculate hessian. Defaults to "autograd".
+         vectorize (bool, optional):
+             whether to enable vectorized hessian. Defaults to True.
+         inner (Chainable | None, optional): inner modules. Defaults to None.
+         H_tfm (Callable | None, optional):
+             optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
+
+             must return a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+             which must be True if transform inverted the hessian and False otherwise. Defaults to None.
+         eigval_tfm (Callable | None, optional):
+             optional eigenvalues transform, for example :code:`torch.abs` or :code:`lambda L: torch.clip(L, min=1e-8)`.
+             If this is specified, eigendecomposition will be used to solve Hx = g.
+
+     """
+     def __init__(
+         self,
+         reg: float = 1e-6,
+         eig_reg: bool = False,
+         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+         vectorize: bool = True,
+         inner: Chainable | None = None,
+         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
+         eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+     ):
+         defaults = dict(reg=reg, eig_reg=eig_reg, abs=abs,hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = TensorList(vars.params)
+         closure = vars.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         reg = settings['reg']
+         eig_reg = settings['eig_reg']
+         hessian_method = settings['hessian_method']
+         vectorize = settings['vectorize']
+         H_tfm = settings['H_tfm']
+         eigval_tfm = settings['eigval_tfm']
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         if hessian_method == 'autograd':
+             with torch.enable_grad():
+                 loss = vars.loss = vars.loss_approx = closure(False)
+                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                 g_list = [t[0] for t in g_list] # remove leading dim from loss
+                 vars.grad = g_list
+                 H = hessian_list_to_mat(H_list)
+
+         elif hessian_method in ('func', 'autograd.functional'):
+             strat = 'forward-mode' if vectorize else 'reverse-mode'
+             with torch.enable_grad():
+                 g_list = vars.get_grad(retain_graph=True)
+                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
+                     method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+         else:
+             raise ValueError(hessian_method)
+
+         # -------------------------------- inner step -------------------------------- #
+         if 'inner' in self.children:
+             g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
+         g = torch.cat([t.view(-1) for t in g_list])
+
+         # ------------------------------- regulazition ------------------------------- #
+         if eig_reg: H = eig_tikhonov_(H, reg)
+         else: H = tikhonov_(H, reg)
+
+         # ----------------------------------- solve ---------------------------------- #
+         update = None
+         if H_tfm is not None:
+             H, is_inv = H_tfm(H, g)
+             if is_inv: update = H
+
+         if eigval_tfm is not None:
+             update = eigh_solve(H, g, eigval_tfm)
+
+         if update is None: update = cholesky_solve(H, g)
+         if update is None: update = lu_solve(H, g)
+         if update is None: update = least_squares_solve(H, g)
+
+         vars.update = vec_to_tensors(update, params)
+         return vars
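
Reading the new Newton.step from the hunk above: the Hessian is Tikhonov-regularized (H <- H + reg*I, or shifted by the most negative eigenvalue when eig_reg is set), and the system H x = g is then solved by Cholesky, falling back to LU and finally least squares; if eigval_tfm is given, the solve instead goes through the eigendecomposition, x = Q diag(1/tfm(L)) Q^T g. A minimal standalone sketch of the same fallback chain in plain PyTorch (the helper name newton_step is ours, not part of the package):

    import torch

    def newton_step(H: torch.Tensor, g: torch.Tensor, reg: float = 1e-6) -> torch.Tensor:
        # Tikhonov regularization, mirroring tikhonov_ above: H <- H + reg * I.
        H = H + reg * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)
        # 1) Cholesky solve: fastest path, requires positive definiteness.
        L, info = torch.linalg.cholesky_ex(H)
        if info == 0:
            return torch.cholesky_solve(g.unsqueeze(1), L).squeeze(1)
        # 2) LU solve: handles indefinite but non-singular H.
        x, info = torch.linalg.solve_ex(H, g)
        if info == 0:
            return x
        # 3) Least squares: last resort, tolerates singular H.
        return torch.linalg.lstsq(H, g.unsqueeze(1)).solution.squeeze(1)
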
torchzero/modules/second_order/newton_cg.py
@@ -0,0 +1,84 @@
+ from collections.abc import Callable
+ from typing import Literal, overload
+ import warnings
+ import torch
+
+ from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
+ from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+ from ...core import Chainable, apply, Module
+ from ...utils.linalg.solve import cg
+
+ class NewtonCG(Module):
+     def __init__(
+         self,
+         maxiter=None,
+         tol=1e-3,
+         reg: float = 1e-8,
+         hvp_method: Literal["forward", "central", "autograd"] = "forward",
+         h=1e-3,
+         warm_start=False,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, warm_start=warm_start)
+         super().__init__(defaults,)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = TensorList(vars.params)
+         closure = vars.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         tol = settings['tol']
+         reg = settings['reg']
+         maxiter = settings['maxiter']
+         hvp_method = settings['hvp_method']
+         h = settings['h']
+         warm_start = settings['warm_start']
+
+         # ---------------------- Hessian vector product function --------------------- #
+         if hvp_method == 'autograd':
+             grad = vars.get_grad(create_graph=True)
+
+             def H_mm(x):
+                 with torch.enable_grad():
+                     return TensorList(hvp(params, grad, x, retain_graph=True))
+
+         else:
+
+             with torch.enable_grad():
+                 grad = vars.get_grad()
+
+             if hvp_method == 'forward':
+                 def H_mm(x):
+                     return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+
+             elif hvp_method == 'central':
+                 def H_mm(x):
+                     return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+
+             else:
+                 raise ValueError(hvp_method)
+
+
+         # -------------------------------- inner step -------------------------------- #
+         b = grad
+         if 'inner' in self.children:
+             b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+
+         # ---------------------------------- run cg ---------------------------------- #
+         x0 = None
+         if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+         x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
+         if warm_start:
+             assert x0 is not None
+             x0.set_(x)
+
+         vars.update = x
+         return vars
+
+
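
NewtonCG above never materializes the Hessian: cg (from utils.linalg.solve) solves H x = b where H is only touched through the H_mm closure, i.e. through Hessian-vector products, with reg acting as Tikhonov damping on the system and warm_start reusing the previous CG solution as the initial guess. With hvp_method="autograd" the product is exact (a second backward pass through the gradient); "forward" and "central" presumably use the standard finite-difference approximations their helper names suggest:

    Hv \approx \frac{\nabla f(x + hv) - \nabla f(x)}{h} \qquad \text{(forward)}
    Hv \approx \frac{\nabla f(x + hv) - \nabla f(x - hv)}{2h} \qquad \text{(central)}

The forward variant reuses the already-computed gradient (g_0=grad), so each product costs roughly one extra gradient evaluation; the central variant costs two but is more accurate.
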
torchzero/modules/second_order/nystrom.py
@@ -0,0 +1,168 @@
+ from collections.abc import Callable
+ from typing import Literal, overload
+ import warnings
+ import torch
+
+ from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
+ from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+ from ...core import Chainable, apply, Module
+ from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg
+
+ class NystromSketchAndSolve(Module):
+     def __init__(
+         self,
+         rank: int,
+         reg: float = 1e-3,
+         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+         h=1e-3,
+         inner: Chainable | None = None,
+         seed: int | None = None,
+     ):
+         defaults = dict(rank=rank, reg=reg, hvp_method=hvp_method, h=h, seed=seed)
+         super().__init__(defaults,)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = TensorList(vars.params)
+
+         closure = vars.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         rank = settings['rank']
+         reg = settings['reg']
+         hvp_method = settings['hvp_method']
+         h = settings['h']
+
+         seed = settings['seed']
+         generator = None
+         if seed is not None:
+             if 'generator' not in self.global_state:
+                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             generator = self.global_state['generator']
+
+         # ---------------------- Hessian vector product function --------------------- #
+         if hvp_method == 'autograd':
+             grad = vars.get_grad(create_graph=True)
+
+             def H_mm(x):
+                 with torch.enable_grad():
+                     Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
+                     return torch.cat([t.ravel() for t in Hvp])
+
+         else:
+
+             with torch.enable_grad():
+                 grad = vars.get_grad()
+
+             if hvp_method == 'forward':
+                 def H_mm(x):
+                     Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
+                     return torch.cat([t.ravel() for t in Hvp])
+
+             elif hvp_method == 'central':
+                 def H_mm(x):
+                     Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
+                     return torch.cat([t.ravel() for t in Hvp])
+
+             else:
+                 raise ValueError(hvp_method)
+
+
+         # -------------------------------- inner step -------------------------------- #
+         b = grad
+         if 'inner' in self.children:
+             b = apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars)
+
+         # ------------------------------ sketch&n&solve ------------------------------ #
+         x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
+         vars.update = vec_to_tensors(x, reference=params)
+         return vars
+
+
+
+ class NystromPCG(Module):
+     def __init__(
+         self,
+         sketch_size: int,
+         maxiter=None,
+         tol=1e-3,
+         reg: float = 1e-6,
+         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+         h=1e-3,
+         inner: Chainable | None = None,
+         seed: int | None = None,
+     ):
+         defaults = dict(sketch_size=sketch_size, reg=reg, maxiter=maxiter, tol=tol, hvp_method=hvp_method, h=h, seed=seed)
+         super().__init__(defaults,)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, vars):
+         params = TensorList(vars.params)
+
+         closure = vars.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         sketch_size = settings['sketch_size']
+         maxiter = settings['maxiter']
+         tol = settings['tol']
+         reg = settings['reg']
+         hvp_method = settings['hvp_method']
+         h = settings['h']
+
+
+         seed = settings['seed']
+         generator = None
+         if seed is not None:
+             if 'generator' not in self.global_state:
+                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             generator = self.global_state['generator']
+
+
+         # ---------------------- Hessian vector product function --------------------- #
+         if hvp_method == 'autograd':
+             grad = vars.get_grad(create_graph=True)
+
+             def H_mm(x):
+                 with torch.enable_grad():
+                     Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
+                     return torch.cat([t.ravel() for t in Hvp])
+
+         else:
+
+             with torch.enable_grad():
+                 grad = vars.get_grad()
+
+             if hvp_method == 'forward':
+                 def H_mm(x):
+                     Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
+                     return torch.cat([t.ravel() for t in Hvp])
+
+             elif hvp_method == 'central':
+                 def H_mm(x):
+                     Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
+                     return torch.cat([t.ravel() for t in Hvp])
+
+             else:
+                 raise ValueError(hvp_method)
+
+
+         # -------------------------------- inner step -------------------------------- #
+         b = grad
+         if 'inner' in self.children:
+             b = apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars)
+
+         # ------------------------------ sketch&n&solve ------------------------------ #
+         x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
+         vars.update = vec_to_tensors(x, reference=params)
+         return vars
+
+
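
Both Nystrom modules above likewise access the Hessian only through H_mm. Assuming nystrom_sketch_and_solve and nystrom_pcg implement the randomized Nystrom construction their names suggest, they draw a random test matrix Omega with rank (resp. sketch_size) columns, form the sketch Y = H Omega via Hessian-vector products, and build the low-rank approximation

    \hat{H} = Y\, (\Omega^{\top} Y)^{+}\, Y^{\top}

NystromSketchAndSolve would then solve (\hat{H} + reg I) x = b directly from the low-rank factors, while NystromPCG would use the approximation as a preconditioner for conjugate gradients on the regularized system, so sketch_size trades preconditioner quality against the cost of extra Hessian-vector products per step.
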
torchzero/modules/smoothing/__init__.py
@@ -1,5 +1,2 @@
- r"""
- Gradient smoothing and orthogonalization methods.
- """
- from .laplacian_smoothing import LaplacianSmoothing, gradient_laplacian_smoothing_
- from .gaussian_smoothing import GaussianHomotopy
+ from .laplacian import LaplacianSmoothing
+ from .gaussian import GaussianHomotopy