torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/__init__.py
@@ -1,41 +1,18 @@
- """This submodule contains various untested experimental modules, some of them are to be moved out of experimental when properly tested, some are to remain here forever or to be deleted depending on the degree of their usefulness."""
- from .absoap import ABSOAP
- from .adadam import Adadam
- from .adam_lambertw import AdamLambertW
- from .adamY import AdamY
- from .adaptive_step_size import AdaptiveStepSize
- from .adasoap import AdaSOAP
- from .cosine import (
-     AdaptiveDifference,
-     AdaptiveDifferenceEMA,
-     CosineDebounce,
-     CosineMomentum,
-     CosineStepSize,
-     ScaledAdaptiveDifference,
- )
- from .cubic_adam import CubicAdam
+ """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
  from .curveball import CurveBall

  # from dct import DCTProjection
- from .eigendescent import EigenDescent
- from .etf import (
-     ExponentialTrajectoryFit,
-     ExponentialTrajectoryFitV2,
-     PointwiseExponential,
- )
- from .exp_adam import ExpAdam
- from .expanded_lbfgs import ExpandedLBFGS
  from .fft import FFTProjection
  from .gradmin import GradMin
- from .hnewton import HNewton
- from .modular_lbfgs import ModularLBFGS
+ from .l_infinity import InfinityNormTrustRegion
+ from .momentum import (
+     CoordinateMomentum,
+     NesterovEMASquared,
+     PrecenteredEMASquared,
+     SqrtNesterovEMASquared,
+ )
  from .newton_solver import NewtonSolver
  from .newtonnewton import NewtonNewton
- from .parabolic_search import CubicParabolaSearch, ParabolaSearch
  from .reduce_outward_lr import ReduceOutwardLR
+ from .scipy_newton_cg import ScipyNewtonCG
  from .structural_projections import BlockPartition, TensorizeProjection
- from .subspace_preconditioners import (
-     HistorySubspacePreconditioning,
-     RandomSubspacePreconditioning,
- )
- from .tensor_adagrad import TensorAdagrad
torchzero/modules/experimental/dct.py
@@ -54,8 +54,8 @@ class DCTProjection(ProjectionBase):
          return projected

      @torch.no_grad
-     def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
-         settings = projected_settings[0]
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+         settings = settings[0]
          dims = settings['dims']
          norm = settings['norm']

torchzero/modules/experimental/fft.py
@@ -60,8 +60,8 @@ class FFTProjection(ProjectionBase):
          return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable

      @torch.no_grad
-     def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
-         settings = projected_settings[0]
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+         settings = settings[0]
          one_d = settings['one_d']
          norm = settings['norm']

torchzero/modules/experimental/gradmin.py
@@ -5,11 +5,11 @@ from typing import Literal

  import torch

- from ...core import Module, Var
+ from ...core import Module, Var, Chainable
  from ...utils import NumberList, TensorList
  from ...utils.derivatives import jacobian_wrt
  from ..grad_approximation import GradApproximator, GradTarget
- from ..smoothing.gaussian import Reformulation
+ from ..smoothing.sampling import Reformulation



@@ -28,6 +28,7 @@ class GradMin(Reformulation):
      """
      def __init__(
          self,
+         modules: Chainable,
          loss_term: float | None = 0,
          relative: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
          graft: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
@@ -39,7 +40,7 @@ class GradMin(Reformulation):
      ):
          if (relative is not None) and (graft is not None): warnings.warn('both relative and graft loss are True, they will clash with each other')
          defaults = dict(loss_term=loss_term, relative=relative, graft=graft, square=square, mean=mean, maximize_grad=maximize_grad, create_graph=create_graph, modify_loss=modify_loss)
-         super().__init__(defaults)
+         super().__init__(defaults, modules=modules)

      @torch.no_grad
      def closure(self, backward, closure, params, var):
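GradMin now takes the modules it wraps as its first argument and forwards them to the ``Reformulation`` base class. A minimal construction sketch, assuming ``GradMin`` and ``LBFGS`` are exposed under ``tz.m`` in the same way as the docstring examples elsewhere in this diff:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    # sketch only: GradMin reformulates the objective and hands the reformulated
    # closure to the wrapped module (here LBFGS), matching the new `modules` argument
    opt = tz.Modular(
        model.parameters(),
        tz.m.GradMin(tz.m.LBFGS()),
    )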
torchzero/modules/experimental/l_infinity.py (new file)
@@ -0,0 +1,111 @@
+
+ import numpy as np
+ import torch
+ from scipy.optimize import lsq_linear
+
+ from ...core import Chainable, Module
+ from ..trust_region.trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
+
+
+ class InfinityNormTrustRegion(TrustRegionBase):
+     """Trust region with L-infinity norm via ``scipy.optimize.lsq_linear``.
+
+     Args:
+         hess_module (Module | None, optional):
+             A module that maintains a hessian approximation (not hessian inverse!).
+             This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
+             When using quasi-newton methods, set `inverse=False` when constructing them.
+         eta (float, optional):
+             if ratio of actual to predicted rediction is larger than this, step is accepted.
+             When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
+         nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
+         nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
+         rho_good (float, optional):
+             if ratio of actual to predicted rediction is larger than this, trust region size is multiplied by `nplus`.
+         rho_bad (float, optional):
+             if ratio of actual to predicted rediction is less than this, trust region size is multiplied by `nminus`.
+         init (float, optional): Initial trust region value. Defaults to 1.
+         update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+         max_attempts (max_attempts, optional):
+             maximum number of trust region size size reductions per step. A zero update vector is returned when
+             this limit is exceeded. Defaults to 10.
+         boundary_tol (float | None, optional):
+             The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
+             This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-2.
+         tol (float | None, optional): tolerance for least squares solver.
+         fallback (bool, optional):
+             if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
+             be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
+         inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
+
+     Examples:
+         BFGS with infinity-norm trust region
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)),
+             )
+     """
+     def __init__(
+         self,
+         hess_module: Module,
+         prefer_dense:bool=True,
+         tol: float = 1e-10,
+         eta: float= 0.0,
+         nplus: float = 3.5,
+         nminus: float = 0.25,
+         rho_good: float = 0.99,
+         rho_bad: float = 1e-4,
+         boundary_tol: float | None = None,
+         init: float = 1,
+         max_attempts: int = 10,
+         radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
+         update_freq: int = 1,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(tol=tol, prefer_dense=prefer_dense)
+         super().__init__(
+             defaults=defaults,
+             hess_module=hess_module,
+             eta=eta,
+             nplus=nplus,
+             nminus=nminus,
+             rho_good=rho_good,
+             rho_bad=rho_bad,
+             boundary_tol=boundary_tol,
+             init=init,
+             max_attempts=max_attempts,
+             radius_strategy=radius_strategy,
+             update_freq=update_freq,
+             inner=inner,
+
+             radius_fn=torch.amax,
+         )
+
+     def trust_solve(self, f, g, H, radius, params, closure, settings):
+         if settings['prefer_dense'] and H.is_dense():
+             # convert to array if possible to avoid many conversions
+             # between torch and numpy, plus it seems that it uses
+             # a better solver
+             A = H.to_tensor().numpy(force=True).astype(np.float64)
+         else:
+             # memory efficient linear operator (is this still faster on CUDA?)
+             A = H.scipy_linop()
+
+         try:
+             d_np = lsq_linear(
+                 A,
+                 g.numpy(force=True).astype(np.float64),
+                 tol=settings['bounds'],
+                 bounds=(-radius, radius),
+             ).x
+             return torch.as_tensor(d_np, device=g.device, dtype=g.dtype)
+
+         except np.linalg.LinAlgError:
+             self.children['hess_module'].reset()
+             g_max = g.amax()
+             if g_max > radius:
+                 g = g * (radius / g_max)
+             return g
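For reference, ``trust_solve`` above reduces the L-infinity trust-region subproblem to a box-constrained linear least-squares solve. A standalone sketch of the same ``scipy.optimize.lsq_linear`` call on a hypothetical 2x2 model (values are illustrative, not taken from the package):

    import numpy as np
    from scipy.optimize import lsq_linear

    # hypothetical quadratic model: minimize ||H d - g||^2 subject to |d_i| <= radius
    H = np.array([[2.0, 0.3],
                  [0.3, 1.0]])
    g = np.array([1.0, -4.0])
    radius = 0.5

    res = lsq_linear(H, g, bounds=(-radius, radius))
    d = res.x  # Newton-like step clipped to the L-infinity trust-region box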
torchzero/modules/{momentum/experimental.py → experimental/momentum.py}
@@ -6,10 +6,10 @@ from typing import Literal
  import torch

  from ...core import Target, Transform
- from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
  from ..functional import ema_, ema_sq_, sqrt_ema_sq_
- from .ema import EMASquared, SqrtEMASquared
- from .momentum import nag_
+ from ..momentum.momentum import nag_
+ from ..ops.higher_level import EMASquared, SqrtEMASquared


  def precentered_ema_sq_(
@@ -158,40 +158,3 @@ class CoordinateMomentum(Transform):
          p = NumberList(s['p'] for s in settings)
          velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
          return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
-
-
- # def multiplicative_momentum_(
- #     tensors_: TensorList,
- #     velocity_: TensorList,
- #     momentum: float | NumberList,
- #     dampening: float | NumberList,
- #     normalize_velocity: bool = True,
- #     abs: bool = False,
- #     lerp: bool = False,
- # ):
- #     """
- #     abs: if True, tracks momentum of absolute magnitudes.
-
- #     returns `tensors_`.
- #     """
- #     tensors_into_velocity = tensors_.abs() if abs else tensors_
- #     ema_(tensors_into_velocity, exp_avg_=velocity_, beta=momentum, dampening=0, lerp=lerp)
-
- #     if normalize_velocity: velocity_ = velocity_ / velocity_.std().add_(1e-8)
- #     return tensors_.mul_(velocity_.lazy_mul(1-dampening) if abs else velocity_.abs().lazy_mul_(1-dampening))
-
-
- # class MultiplicativeMomentum(Transform):
- #     """sucks"""
- #     def __init__(self, momentum: float = 0.9, dampening: float = 0,normalize_velocity: bool = True, abs: bool = False, lerp: bool = False):
- #         defaults = dict(momentum=momentum, dampening=dampening, normalize_velocity=normalize_velocity,abs=abs, lerp=lerp)
- #         super().__init__(defaults, uses_grad=False)
-
- #     @torch.no_grad
- #     def apply(self, tensors, params, grads, loss, states, settings):
- #         momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
- #         abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
- #         velocity = self.get_state('velocity', params=params, cls=TensorList)
- #         return multiplicative_momentum_(TensorList(target), velocity_=velocity, momentum=momentum, dampening=dampening,
- #             normalize_velocity=normalize_velocity,abs=abs,lerp=lerp)
-
torchzero/modules/experimental/newton_solver.py
@@ -3,28 +3,36 @@ from typing import Any, Literal, overload

  import torch

- from ...core import Chainable, Module, apply_transform, Modular
+ from ...core import Chainable, Modular, Module, apply_transform
  from ...utils import TensorList, as_tensorlist
- from ...utils.derivatives import hvp
+ from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
  from ..quasi_newton import LBFGS

+
  class NewtonSolver(Module):
-     """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)"""
+     """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)."""
      def __init__(
          self,
          solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
          maxiter=None,
-         tol=1e-3,
+         maxiter1=None,
+         tol:float | None=1e-3,
          reg: float = 0,
          warm_start=True,
+         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+         reset_solver: bool = False,
+         h: float= 1e-3,
          inner: Chainable | None = None,
      ):
-         defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+         defaults = dict(tol=tol, h=h,reset_solver=reset_solver, maxiter=maxiter, maxiter1=maxiter1, reg=reg, warm_start=warm_start, solver=solver, hvp_method=hvp_method)
          super().__init__(defaults,)

          if inner is not None:
              self.set_child('inner', inner)

+         self._num_hvps = 0
+         self._num_hvps_last_step = 0
+
      @torch.no_grad
      def step(self, var):
          params = TensorList(var.params)
@@ -34,19 +42,49 @@ class NewtonSolver(Module):
          settings = self.settings[params[0]]
          solver_cls = settings['solver']
          maxiter = settings['maxiter']
+         maxiter1 = settings['maxiter1']
          tol = settings['tol']
          reg = settings['reg']
+         hvp_method = settings['hvp_method']
          warm_start = settings['warm_start']
+         h = settings['h']
+         reset_solver = settings['reset_solver']

+         self._num_hvps_last_step = 0
          # ---------------------- Hessian vector product function --------------------- #
-         grad = var.get_grad(create_graph=True)
+         if hvp_method == 'autograd':
+             grad = var.get_grad(create_graph=True)

-         def H_mm(x):
-             with torch.enable_grad():
-                 Hvp = TensorList(hvp(params, grad, x, create_graph=True))
+             def H_mm(x):
+                 self._num_hvps_last_step += 1
+                 with torch.enable_grad():
+                     Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
                  if reg != 0: Hvp = Hvp + (x*reg)
                  return Hvp

+         else:
+
+             with torch.enable_grad():
+                 grad = var.get_grad()
+
+             if hvp_method == 'forward':
+                 def H_mm(x):
+                     self._num_hvps_last_step += 1
+                     Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                     if reg != 0: Hvp = Hvp + (x*reg)
+                     return Hvp
+
+             elif hvp_method == 'central':
+                 def H_mm(x):
+                     self._num_hvps_last_step += 1
+                     Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                     if reg != 0: Hvp = Hvp + (x*reg)
+                     return Hvp
+
+             else:
+                 raise ValueError(hvp_method)
+
+
          # -------------------------------- inner step -------------------------------- #
          b = as_tensorlist(grad)
          if 'inner' in self.children:
@@ -58,23 +96,46 @@ class NewtonSolver(Module):
          if x0 is None: x = b.zeros_like().requires_grad_(True)
          else: x = x0.clone().requires_grad_(True)

-         solver = solver_cls(x)
+
+         if 'solver' not in self.global_state:
+             if maxiter1 is not None: maxiter = maxiter1
+             solver = self.global_state['solver'] = solver_cls(x)
+             self.global_state['x'] = x
+
+         else:
+             if reset_solver:
+                 solver = self.global_state['solver'] = solver_cls(x)
+             else:
+                 solver_params = self.global_state['x']
+                 solver_params.set_(x)
+                 x = solver_params
+                 solver = self.global_state['solver']
+
          def lstsq_closure(backward=True):
-             Hx = H_mm(x)
-             loss = (Hx-b).pow(2).global_mean()
+             Hx = H_mm(x).detach()
+             # loss = (Hx-b).pow(2).global_mean()
+             # if backward:
+             #     solver.zero_grad()
+             #     loss.backward(inputs=x)
+
+             residual = Hx - b
+             loss = residual.pow(2).global_mean()
              if backward:
-                 solver.zero_grad()
-                 loss.backward(inputs=x)
+                 with torch.no_grad():
+                     H_residual = H_mm(residual)
+                     n = residual.global_numel()
+                     x.set_grad_((2.0 / n) * H_residual)
+
              return loss

          if maxiter is None: maxiter = b.global_numel()
          loss = None
-         initial_loss = lstsq_closure(False)
-         if initial_loss > tol:
+         initial_loss = lstsq_closure(False) if tol is not None else None # skip unnecessary closure if tol is None
+         if initial_loss is None or initial_loss > torch.finfo(b[0].dtype).eps:
              for i in range(maxiter):
                  loss = solver.step(lstsq_closure)
                  assert loss is not None
-                 if min(loss, loss/initial_loss) < tol: break
+                 if initial_loss is not None and loss/initial_loss < tol: break

          # print(f'{loss = }')

@@ -83,6 +144,7 @@ class NewtonSolver(Module):
              x0.copy_(x)

          var.update = x.detach()
+         self._num_hvps += self._num_hvps_last_step
          return var

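The new ``hvp_method``, ``h``, ``maxiter1`` and ``reset_solver`` options are plain per-module settings, so they are configured at construction time. A usage sketch in the style of the docstring examples in this diff, assuming ``NewtonSolver`` is reachable as ``tz.m.NewtonSolver``:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    # sketch only: "central" selects the finite-difference Hessian-vector products
    # added in this release; with reset_solver=False and warm_start=True the inner
    # least-squares solver and its iterate are reused across steps
    opt = tz.Modular(
        model.parameters(),
        tz.m.NewtonSolver(hvp_method="central", h=1e-3, warm_start=True, reset_solver=False),
    )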
 
torchzero/modules/experimental/newtonnewton.py
@@ -10,16 +10,16 @@ import torch
  from ...core import Chainable, Module, apply_transform
  from ...utils import TensorList, vec_to_tensors
  from ...utils.derivatives import (
-     hessian_list_to_mat,
+     flatten_jacobian,
      jacobian_wrt,
  )
  from ..second_order.newton import (
-     cholesky_solve,
-     eigh_solve,
-     least_squares_solve,
-     lu_solve,
+     _cholesky_solve,
+     _eigh_solve,
+     _least_squares_solve,
+     _lu_solve,
  )
-
+ from ...utils.linalg.linear_operator import Dense

  class NewtonNewton(Module):
      """Applies Newton-like preconditioning to Newton step.
@@ -51,10 +51,10 @@ class NewtonNewton(Module):
          super().__init__(defaults)

      @torch.no_grad
-     def step(self, var):
+     def update(self, var):
          params = TensorList(var.params)
          closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
+         if closure is None: raise RuntimeError('NewtonNewton requires closure')

          settings = self.settings[params[0]]
          reg = settings['reg']
@@ -64,6 +64,7 @@ class NewtonNewton(Module):
          eigval_tfm = settings['eigval_tfm']

          # ------------------------ calculate grad and hessian ------------------------ #
+         Hs = []
          with torch.enable_grad():
              loss = var.loss = var.loss_approx = closure(False)
              g_list = torch.autograd.grad(loss, params, create_graph=True)
@@ -76,17 +77,29 @@ class NewtonNewton(Module):
              is_last = o == order
              H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
              with torch.no_grad() if is_last else nullcontext():
-                 H = hessian_list_to_mat(H_list)
+                 H = flatten_jacobian(H_list)
                  if reg != 0: H = H + I * reg
+                 Hs.append(H)

                  x = None
                  if search_negative or (is_last and eigval_tfm is not None):
-                     x = eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
-                 if x is None: x = cholesky_solve(H, xp)
-                 if x is None: x = lu_solve(H, xp)
-                 if x is None: x = least_squares_solve(H, xp)
+                     x = _eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+                 if x is None: x = _cholesky_solve(H, xp)
+                 if x is None: x = _lu_solve(H, xp)
+                 if x is None: x = _least_squares_solve(H, xp)
                  xp = x.squeeze()

-         var.update = vec_to_tensors(xp.nan_to_num_(0,0,0), params)
+         self.global_state["Hs"] = Hs
+         self.global_state['xp'] = xp.nan_to_num_(0,0,0)
+
+     @torch.no_grad
+     def apply(self, var):
+         params = var.params
+         xp = self.global_state['xp']
+         var.update = vec_to_tensors(xp, params)
          return var

+     def get_H(self, var):
+         Hs = self.global_state["Hs"]
+         if len(Hs) == 1: return Dense(Hs[0])
+         return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
torchzero/modules/experimental/scipy_newton_cg.py (new file)
@@ -0,0 +1,105 @@
+ from typing import Literal, overload
+
+ import torch
+ from scipy.sparse.linalg import LinearOperator, gcrotmk
+
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import NumberList, TensorList, as_tensorlist, generic_vector_norm, vec_to_tensors
+ from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+ from ...utils.linalg.solve import cg, minres
+
+
+ class ScipyNewtonCG(Module):
+     """NewtonCG with scipy solvers (any from scipy.sparse.linalg)"""
+     def __init__(
+         self,
+         solver = gcrotmk,
+         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+         h: float = 1e-3,
+         warm_start=False,
+         inner: Chainable | None = None,
+         kwargs: dict | None = None,
+     ):
+         defaults = dict(hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
+         super().__init__(defaults,)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+         self._num_hvps = 0
+         self._num_hvps_last_step = 0
+
+         if kwargs is None: kwargs = {}
+         self._kwargs = kwargs
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         solver = settings['solver']
+         h = settings['h']
+         warm_start = settings['warm_start']
+
+         self._num_hvps_last_step = 0
+         # ---------------------- Hessian vector product function --------------------- #
+         device = params[0].device; dtype=params[0].dtype
+         if hvp_method == 'autograd':
+             grad = var.get_grad(create_graph=True)
+
+             def H_mm(x_np):
+                 self._num_hvps_last_step += 1
+                 x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                 with torch.enable_grad():
+                     Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
+                 return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+         else:
+
+             with torch.enable_grad():
+                 grad = var.get_grad()
+
+             if hvp_method == 'forward':
+                 def H_mm(x_np):
+                     self._num_hvps_last_step += 1
+                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                     Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+             elif hvp_method == 'central':
+                 def H_mm(x_np):
+                     self._num_hvps_last_step += 1
+                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                     Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+             else:
+                 raise ValueError(hvp_method)
+
+         ndim = sum(p.numel() for p in params)
+         H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore
+
+         # -------------------------------- inner step -------------------------------- #
+         b = var.get_update()
+         if 'inner' in self.children:
+             b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+         b = as_tensorlist(b)
+
+         # ---------------------------------- run cg ---------------------------------- #
+         x0 = None
+         if warm_start: x0 = self.global_state.get('x_prev', None) # initialized to 0 which is default anyway
+
+         x_np = solver(H, b.to_vec().nan_to_num().numpy(force=True), x0=x0, **self._kwargs)
+         if isinstance(x_np, tuple): x_np = x_np[0]
+
+         if warm_start:
+             self.global_state['x_prev'] = x_np
+
+         var.update = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)
+
+         self._num_hvps += self._num_hvps_last_step
+         return var
+
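Because the Hessian is wrapped as a ``scipy.sparse.linalg.LinearOperator``, any solver with the ``solver(A, b, x0=...)`` calling convention should plug into the ``solver`` argument, with extra keyword arguments passed through ``kwargs``. A hedged usage sketch using ``lgmres`` instead of the default ``gcrotmk``, assuming ``ScipyNewtonCG`` is reachable as ``tz.m.ScipyNewtonCG``:

    import torch
    import torchzero as tz
    from scipy.sparse.linalg import lgmres

    model = torch.nn.Linear(4, 1)

    # sketch only: lgmres sees the Hessian only through matrix-vector products,
    # so any matrix-free Krylov solver from scipy.sparse.linalg can be dropped in
    opt = tz.Modular(
        model.parameters(),
        tz.m.ScipyNewtonCG(solver=lgmres, hvp_method="forward", kwargs={"maxiter": 20}),
    )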
torchzero/modules/experimental/structural_projections.py
@@ -5,7 +5,7 @@ import torch

  from ...core import Chainable
  from ...utils import vec_to_tensors, TensorList
- from ..optimizers.shampoo import _merge_small_dims
+ from ..adaptive.shampoo import _merge_small_dims
  from ..projections import ProjectionBase
