torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/trust_region/levenberg_marquardt.py
@@ -0,0 +1,128 @@
+ # pylint:disable=not-callable
+ import torch
+
+ from ...core import Chainable, Module
+ from ...utils.linalg import linear_operator
+ from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
+
+
+ class LevenbergMarquardt(TrustRegionBase):
+     """Levenberg-Marquardt trust region algorithm.
+
+
+     Args:
+         hess_module (Module | None, optional):
+             A module that maintains a hessian approximation (not hessian inverse!).
+             This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
+             When using quasi-newton methods, set ``inverse=False`` when constructing them.
+         y (float, optional):
+             when ``y=0``, the identity matrix is added to the hessian; when ``y=1``, the diagonal of the hessian approximation
+             is added. Values in between interpolate. This should only be used with Gauss-Newton. Defaults to 0.
+         eta (float, optional):
+             if the ratio of actual to predicted reduction is larger than this, the step is accepted.
+             When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.
+         nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
+         nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
+         rho_good (float, optional):
+             if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by ``nplus``.
+         rho_bad (float, optional):
+             if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by ``nminus``.
+         init (float, optional): initial trust region value. Defaults to 1.
+         update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+         max_attempts (int, optional):
+             maximum number of trust region size reductions per step. A zero update vector is returned when
+             this limit is exceeded. Defaults to 10.
+         fallback (bool, optional):
+             if ``True``, when ``hess_module`` maintains a hessian inverse which can't be inverted efficiently, it will
+             be inverted anyway. When ``False`` (default), a ``RuntimeError`` is raised instead.
+         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+     Examples:
+         Gauss-Newton with Levenberg-Marquardt trust-region
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
+             )
+
+         LM-SR1
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+             )
+
+         First-order trust region (the hessian is assumed to be identity)
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.LevenbergMarquardt(tz.m.Identity()),
+             )
+
+     """
+     def __init__(
+         self,
+         hess_module: Chainable,
+         eta: float = 0.0,
+         nplus: float = 3.5,
+         nminus: float = 0.25,
+         rho_good: float = 0.99,
+         rho_bad: float = 1e-4,
+         init: float = 1,
+         max_attempts: int = 10,
+         radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
+         y: float = 0,
+         fallback: bool = False,
+         update_freq: int = 1,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(y=y, fallback=fallback)
+         super().__init__(
+             defaults=defaults,
+             hess_module=hess_module,
+             eta=eta,
+             nplus=nplus,
+             nminus=nminus,
+             rho_good=rho_good,
+             rho_bad=rho_bad,
+             init=init,
+             max_attempts=max_attempts,
+             radius_strategy=radius_strategy,
+             update_freq=update_freq,
+             inner=inner,
+
+             boundary_tol=None,
+             radius_fn=None,
+         )
+
+     def trust_solve(self, f, g, H, radius, params, closure, settings):
+         y = settings['y']
+
+         if isinstance(H, linear_operator.DenseInverse):
+             if settings['fallback']:
+                 H = H.to_dense()
+             else:
+                 raise RuntimeError(
+                     f"{self.children['hess_module']} maintains a hessian inverse. "
+                     "LevenbergMarquardt requires the hessian, not the inverse. "
+                     "If that module is a quasi-newton module, pass `inverse=False` on initialization. "
+                     "Or pass `fallback=True` to LevenbergMarquardt to allow inverting the hessian inverse, "
+                     "however that can be inefficient and unstable."
+                 )
+
+         reg = 1/radius
+         if y == 0:
+             return H.add_diagonal(reg).solve(g)
+
+         diag = H.diagonal()
+         diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
+         if y != 1: diag = (diag*y) + (1-y)
+         return H.add_diagonal(diag*reg).solve(g)
+
+
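Note: the solve above damps H with 1/radius, so a larger trust radius means less damping. A minimal dense sketch of the same computation, outside the package's LinearOperator API (the `lm_solve` helper below is illustrative, not part of torchzero):

    import torch

    def lm_solve(H: torch.Tensor, g: torch.Tensor, radius: float, y: float = 0.0) -> torch.Tensor:
        # damping strength is the reciprocal of the trust radius
        reg = 1.0 / radius
        if y == 0:
            # y=0: add a multiple of the identity (classic Levenberg-Marquardt)
            damped = H + reg * torch.eye(H.shape[0], dtype=H.dtype)
        else:
            # y>0: damp with (an interpolation toward) the hessian diagonal,
            # guarding against zero/negative entries as the module does
            diag = H.diagonal().clone()
            diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, torch.ones_like(diag), diag)
            if y != 1:
                diag = diag * y + (1 - y)
            damped = H + torch.diag(diag * reg)
        return torch.linalg.solve(damped, g)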
torchzero/modules/trust_region/trust_cg.py
@@ -0,0 +1,97 @@
+ import torch
+
+ from ...core import Chainable, Module
+ from ...utils.linalg import cg, linear_operator
+ from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
+
+
+ class TrustCG(TrustRegionBase):
+     """Trust region via the Steihaug-Toint conjugate gradient method.
+
+     .. note::
+
+         If you wish to use the exact hessian, use the matrix-free :code:`tz.m.NewtonCGSteihaug`,
+         which only uses hessian-vector products. While passing ``tz.m.Newton`` to this module
+         is possible, it is usually less efficient.
+
+     Args:
+         hess_module (Module | None, optional):
+             A module that maintains a hessian approximation (not hessian inverse!).
+             This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
+             When using quasi-newton methods, set ``inverse=False`` when constructing them.
+         eta (float, optional):
+             if the ratio of actual to predicted reduction is larger than this, the step is accepted.
+             When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.
+         nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
+         nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
+         rho_good (float, optional):
+             if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by ``nplus``.
+         rho_bad (float, optional):
+             if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by ``nminus``.
+         init (float, optional): initial trust region value. Defaults to 1.
+         update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+         reg (float, optional): regularization parameter for conjugate gradient. Defaults to 0.
+         max_attempts (int, optional):
+             maximum number of trust region size reductions per step. A zero update vector is returned when
+             this limit is exceeded. Defaults to 10.
+         boundary_tol (float | None, optional):
+             The trust region only increases when the suggested step's norm is at least ``(1-boundary_tol)*trust_region``.
+             This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-1.
+         prefer_exact (bool, optional):
+             when the exact solution can be calculated cheaply without CG (e.g. the hessian is stored as a scaled identity),
+             uses the exact solution. If False, always uses CG. Defaults to True.
+         inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+     Examples:
+         Trust-SR1
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
+             )
+     """
+ def __init__(
56
+ self,
57
+ hess_module: Chainable,
58
+ eta: float= 0.0,
59
+ nplus: float = 3.5,
60
+ nminus: float = 0.25,
61
+ rho_good: float = 0.99,
62
+ rho_bad: float = 1e-4,
63
+ boundary_tol: float | None = 1e-1,
64
+ init: float = 1,
65
+ max_attempts: int = 10,
66
+ radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
67
+ reg: float = 0,
68
+ cg_tol: float = 1e-4,
69
+ prefer_exact: bool = True,
70
+ update_freq: int = 1,
71
+ inner: Chainable | None = None,
72
+ ):
73
+ defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol)
74
+ super().__init__(
75
+ defaults=defaults,
76
+ hess_module=hess_module,
77
+ eta=eta,
78
+ nplus=nplus,
79
+ nminus=nminus,
80
+ rho_good=rho_good,
81
+ rho_bad=rho_bad,
82
+ boundary_tol=boundary_tol,
83
+ init=init,
84
+ max_attempts=max_attempts,
85
+ radius_strategy=radius_strategy,
86
+ update_freq=update_freq,
87
+ inner=inner,
88
+
89
+ radius_fn=torch.linalg.vector_norm,
90
+ )
91
+
92
+ def trust_solve(self, f, g, H, radius, params, closure, settings):
93
+ if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
94
+ return H.solve_bounded(g, radius)
95
+
96
+ x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], tol=settings["cg_tol"])
97
+ return x
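Note: the `cg` helper with a `trust_radius` argument lives in torchzero/utils/linalg/solve.py and is not shown in this diff. For reference, a textbook Steihaug-Toint truncated CG loop, which such a helper presumably resembles (`reg` and the exact stopping rule are omitted here), looks like this:

    import torch

    def steihaug_cg(matvec, g, trust_radius, tol=1e-4, max_iter=None):
        # approximately solve H x = g subject to ||x|| <= trust_radius
        x = torch.zeros_like(g)
        r = g.clone()   # residual of H x = g at x = 0
        p = r.clone()   # search direction
        rr = r.dot(r)
        if max_iter is None: max_iter = g.numel()
        for _ in range(max_iter):
            Hp = matvec(p)
            pHp = p.dot(Hp)
            if pHp <= 0:                       # negative curvature: go to the boundary
                return to_boundary(x, p, trust_radius)
            alpha = rr / pHp
            x_next = x + alpha * p
            if x_next.norm() >= trust_radius:  # step leaves the region: go to the boundary
                return to_boundary(x, p, trust_radius)
            x = x_next
            r = r - alpha * Hp
            rr_next = r.dot(r)
            if rr_next.sqrt() < tol:
                break
            p = r + (rr_next / rr) * p
            rr = rr_next
        return x

    def to_boundary(x, p, trust_radius):
        # smallest tau >= 0 with ||x + tau*p|| = trust_radius
        a, b, c = p.dot(p), 2 * x.dot(p), x.dot(x) - trust_radius ** 2
        tau = (-b + torch.sqrt(b * b - 4 * a * c)) / (2 * a)
        return x + tau * p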
torchzero/modules/trust_region/trust_region.py
@@ -0,0 +1,350 @@
+ import math
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Mapping, Sequence
+ from functools import partial
+ from typing import Any, Literal, Protocol, cast, final, overload
+
+ import torch
+
+ from ...core import Chainable, Module, Var, apply_transform
+ from ...utils import TensorList, safe_dict_update_, tofloat, vec_to_tensors, generic_finfo, generic_vector_norm
+ from ...utils.linalg.linear_operator import LinearOperator
+
+
+ def _flatten_tensors(tensors: list[torch.Tensor]):
+     return torch.cat([t.ravel() for t in tensors])
+
+
+
+ class _RadiusStrategy(Protocol):
+     def __call__(
+         self,
+         params: Sequence[torch.Tensor],
+         closure: Callable,
+         f: float,
+         g: torch.Tensor,
+         H: LinearOperator,
+         d: torch.Tensor,
+         trust_radius: float,
+         eta: float, # 0.0
+         nplus: float, # 3.5
+         nminus: float, # 0.25
+         rho_good: float, # 0.99
+         rho_bad: float, # 1e-4
+         boundary_tol: float | None,
+         init: float,
+         state: Mapping[str, Any],
+         settings: Mapping[str, Any],
+         radius_fn: Callable | None = torch.linalg.vector_norm,
+     ) -> tuple[float, bool]:
+         """returns ``(new trust_radius value, success)``.
+
+         Args:
+             params (Sequence[torch.Tensor]): params tensor list
+             closure (Callable): closure
+             d (torch.Tensor):
+                 current update vector under the current trust_radius, which is SUBTRACTED from the parameters.
+                 May be an exact solution to (B+yI)x=g, an approximate one, or a solution to a different subproblem
+                 (e.g. cubic regularization).
+             f (float | torch.Tensor): loss at x0
+             g (torch.Tensor): gradient vector
+             H (LinearOperator | None): hessian approximation
+             trust_radius (float): current trust region value
+             eta (float, optional):
+                 if the ratio of actual to predicted reduction is larger than this, the step is accepted.
+                 When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.
+             nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
+             nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
+             rho_good (float, optional):
+                 if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by ``nplus``.
+             rho_bad (float, optional):
+                 if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by ``nminus``.
+             boundary_tol (float | None, optional):
+                 The trust region only increases when the suggested step's norm is at least ``(1-boundary_tol)*trust_region``.
+                 This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-1.
+             init (float, optional): initial trust region value. Defaults to 1.
+             state (dict, optional): global state of the module for storing persistent info.
+             settings (dict, optional): all settings, in case this strategy has other settings.
+             radius_fn (Callable | None, optional):
+                 function that accepts ``(d: torch.Tensor)`` and returns the actual radius of ``d``
+                 (e.g. its L2 norm for an L2 trust region).
+         """
+         ... # pylint:disable=unnecessary-ellipsis
+
+ def _get_rho(params: Sequence[torch.Tensor], closure: Callable,
+              f: float, g: torch.Tensor, H: LinearOperator, d: torch.Tensor):
+     """rho is reduction/pred_reduction"""
+
+     # evaluate the actual loss reduction
+     update_unflattened = vec_to_tensors(d, params)
+     params = TensorList(params)
+     x0 = params.clone() # same as in line searches, large directions are undone very imprecisely
+
+     params -= update_unflattened
+     f_star = closure(False)
+     params.set_(x0)
+
+     reduction = f - f_star
+
+     # predicted reduction is g.T @ d - 0.5 * d.T @ B @ d
+     Hu = H.matvec(d)
+     pred_reduction = g.dot(d) - 0.5 * d.dot(Hu)
+
+     rho = reduction / (pred_reduction.clip(min=torch.finfo(g.dtype).tiny * 2))
+     return rho, f_star, reduction, pred_reduction
+
+ def _get_rho_tensorlist(params: Sequence[torch.Tensor], closure: Callable,
+                         f: float, g: TensorList, Hvp: Callable[[TensorList], TensorList], d: TensorList):
+     """rho is reduction/pred_reduction"""
+     params = TensorList(params)
+     x0 = params.clone() # same as in line searches, large directions are undone very imprecisely
+
+     # evaluate before modifying params to not break autograd
+     Hu = Hvp(d)
+
+     # actual f
+     params -= d
+     f_star = closure(False)
+     params.copy_(x0)
+
+     reduction = f - f_star
+
+     # predicted reduction is g.T @ d - 0.5 * d.T @ B @ d
+     pred_reduction = g.dot(d) - 0.5 * d.dot(Hu)
+
+     rho = reduction / (pred_reduction.clip(min=torch.finfo(g[0].dtype).tiny * 2))
+     return rho, f_star, reduction, pred_reduction
+
+ @torch.no_grad
+ def default_radius(
+     params: Sequence[torch.Tensor],
+     closure: Callable,
+     f: float,
+     g: torch.Tensor | TensorList,
+     H: LinearOperator | Callable,
+     d: torch.Tensor | TensorList,
+     trust_radius: float,
+     eta: float, # 0.0
+     nplus: float, # 3.5
+     nminus: float, # 0.25
+     rho_good: float, # 0.99
+     rho_bad: float, # 1e-4
+     boundary_tol: float | None,
+     init: float,
+     state: Mapping[str, Any],
+     settings: Mapping[str, Any],
+     radius_fn: Callable | None = generic_vector_norm,
+     check_overflow: bool = True,
+     # dynamic_nminus: bool=False,
+ ) -> tuple[float, bool]:
+
+     # when rho_bad < rho < eta, the step is rejected but the trust region is not shrunk.
+     if eta > rho_bad:
+         warnings.warn(f"trust region eta={eta} is larger than rho_bad={rho_bad}, "
+                       "this can lead to the trust region getting stuck.")
+
+     if isinstance(g, torch.Tensor):
+         rho, f_star, _, _ = _get_rho(params=params, closure=closure, f=f, g=g, H=H, d=d) # pyright:ignore[reportArgumentType]
+     else:
+         rho, f_star, _, _ = _get_rho_tensorlist(params=params, closure=closure, f=f, g=g, Hvp=H, d=d) # pyright:ignore[reportArgumentType]
+
+     is_finite = math.isfinite(f_star)
+
+     # find the boundary of the current step
+     if radius_fn is None: d_radius = trust_radius
+     else: d_radius = radius_fn(d)
+
+     # failed step
+     if rho < rho_bad or not is_finite:
+         # if dynamic_nminus and rho > 0: nminus = nminus * max(rho, 1e-4)
+         trust_radius = d_radius*nminus
+
+     # very good step
+     elif rho > rho_good and is_finite:
+         if (boundary_tol is None) or (trust_radius-d_radius)/trust_radius < boundary_tol:
+             trust_radius = max(trust_radius, d_radius*nplus)
+
+     # prevent very small or large values
+     if check_overflow:
+         finfo = generic_finfo(g)
+         if trust_radius < finfo.tiny*2 or trust_radius > finfo.max/2:
+             trust_radius = init
+
+     # return the new trust region and a success boolean
+     return tofloat(trust_radius), rho > eta and is_finite
+
+
+ def fixed_radius(
+     params: Sequence[torch.Tensor],
+     closure: Callable,
+     f: float,
+     g: torch.Tensor,
+     H: LinearOperator,
+     d: torch.Tensor,
+     trust_radius: float,
+     eta: float, # 0.0
+     nplus: float, # 3.5
+     nminus: float, # 0.25
+     rho_good: float, # 0.99
+     rho_bad: float, # 1e-4
+     boundary_tol: float | None,
+     init: float,
+     state: Mapping[str, Any],
+     settings: Mapping[str, Any],
+     radius_fn: Callable | None = torch.linalg.vector_norm,
+ ) -> tuple[float, bool]:
+     return init, True
+
+ _RADIUS_KEYS = Literal['default', 'fixed']
+ _RADIUS_STRATEGIES: dict[_RADIUS_KEYS, _RadiusStrategy] = {
+     "default": default_radius,
+     "fixed": fixed_radius,
+     # "dynamic": partial(default_radius, dynamic_nminus=True)
+ }
+
+ class TrustRegionBase(Module, ABC):
+     def __init__(
+         self,
+         defaults: dict | None,
+         hess_module: Chainable,
+         # suggested default values:
+         # Gould, Nicholas IM, et al. "Sensitivity of trust-region algorithms to their parameters." 4OR 3.3 (2005): 227-241.
+         # which I found from https://github.com/patrick-kidger/optimistix/blob/c1dad7e75fc35bd5a4977ac3a872991e51e83d2c/optimistix/_solver/trust_region.py#L113-200
+         eta: float, # 0.0
+         nplus: float, # 3.5
+         nminus: float, # 0.25
+         rho_good: float, # 0.99
+         rho_bad: float, # 1e-4
+         boundary_tol: float | None, # None or 1e-1
+         init: float, # 1
+         max_attempts: int, # 10
+         radius_strategy: _RadiusStrategy | _RADIUS_KEYS, # "default"
+         radius_fn: Callable | None, # torch.linalg.vector_norm
+         update_freq: int = 1,
+         inner: Chainable | None = None,
+     ):
+         if isinstance(radius_strategy, str): radius_strategy = _RADIUS_STRATEGIES[radius_strategy]
+         if defaults is None: defaults = {}
+
+         safe_dict_update_(
+             defaults,
+             dict(eta=eta, nplus=nplus, nminus=nminus, rho_good=rho_good, rho_bad=rho_bad, init=init,
+                  update_freq=update_freq, max_attempts=max_attempts, radius_strategy=radius_strategy,
+                  boundary_tol=boundary_tol)
+         )
+
+         super().__init__(defaults)
+
+         self._radius_fn = radius_fn
+         self.set_child('hess_module', hess_module)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @abstractmethod
+     def trust_solve(
+         self,
+         f: float,
+         g: torch.Tensor,
+         H: LinearOperator,
+         radius: float,
+         params: list[torch.Tensor],
+         closure: Callable,
+         settings: Mapping[str, Any],
+     ) -> torch.Tensor:
+         """Solve Hx=g with a trust region penalty/bound defined by ``radius``"""
+         ... # pylint:disable=unnecessary-ellipsis
+
+     def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
+         """updates the state of this module after H or B have been updated, if necessary"""
+
+     def trust_region_apply(self, var: Var, tensors: list[torch.Tensor], H: LinearOperator | None) -> Var:
+         """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
+         assert H is not None
+
+         params = TensorList(var.params)
+         settings = self.settings[params[0]]
+         g = _flatten_tensors(tensors)
+
+         max_attempts = settings['max_attempts']
+
+         # loss at x_0
+         loss = var.loss
+         closure = var.closure
+         if closure is None: raise RuntimeError("Trust region requires closure")
+         if loss is None: loss = var.get_loss(False)
+         loss = tofloat(loss)
+
+         # trust region step and update
+         success = False
+         d = None
+         while not success:
+             max_attempts -= 1
+             if max_attempts < 0: break
+
+             trust_radius = self.global_state.get('trust_radius', settings['init'])
+
+             # solve Hx=g
+             d = self.trust_solve(f=loss, g=g, H=H, radius=trust_radius, params=params, closure=closure, settings=settings)
+
+             # update the trust radius
+             radius_strategy: _RadiusStrategy = settings['radius_strategy']
+             self.global_state["trust_radius"], success = radius_strategy(
+                 params=params,
+                 closure=closure,
+                 d=d,
+                 f=loss,
+                 g=g,
+                 H=H,
+                 trust_radius=trust_radius,
+
+                 eta=settings["eta"],
+                 nplus=settings["nplus"],
+                 nminus=settings["nminus"],
+                 rho_good=settings["rho_good"],
+                 rho_bad=settings["rho_bad"],
+                 boundary_tol=settings["boundary_tol"],
+                 init=settings["init"],
+
+                 state=self.global_state,
+                 settings=settings,
+                 radius_fn=self._radius_fn,
+             )
+
+         assert d is not None
+         if success: var.update = vec_to_tensors(d, params)
+         else: var.update = params.zeros_like()
+
+         return var
+
+
+     @final
+     @torch.no_grad
+     def update(self, var):
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         if step % self.defaults["update_freq"] == 0:
+
+             hessian_module = self.children['hess_module']
+             hessian_module.update(var)
+             H = hessian_module.get_H(var)
+             self.global_state["H"] = H
+
+             self.trust_region_update(var, H=H)
+
+
+     @final
+     @torch.no_grad
+     def apply(self, var):
+         H = self.global_state.get('H', None)
+
+         # -------------------------------- inner step -------------------------------- #
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
+
+         # ----------------------------------- apply ---------------------------------- #
+         return self.trust_region_apply(var=var, tensors=update, H=H)
+
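Note: `_get_rho` compares the actual reduction against the quadratic model's prediction. Since the step `d` is subtracted from the parameters, the model value at `x - d` is `f - g.dot(d) + 0.5 * d.dot(H @ d)`, so the predicted reduction is `g.dot(d) - 0.5 * d.dot(H @ d)`. A self-contained sanity check on a toy quadratic (not package code): on an exactly quadratic objective the prediction is exact and rho is 1.

    import torch

    torch.manual_seed(0)
    n = 5
    A = torch.randn(n, n)
    A = A @ A.T + n * torch.eye(n)       # SPD "hessian"
    b = torch.randn(n)
    x = torch.randn(n)

    f = lambda v: 0.5 * v @ A @ v - b @ v
    g = A @ x - b                        # gradient at x
    d = torch.linalg.solve(A, g)         # Newton step, subtracted from x

    reduction = f(x) - f(x - d)
    pred_reduction = g.dot(d) - 0.5 * d.dot(A @ d)
    print(torch.allclose(reduction, pred_reduction))  # True -> rho == 1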
torchzero/modules/variance_reduction/__init__.py
@@ -0,0 +1 @@
+ from .svrg import SVRG
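Note: svrg.py itself is not shown in this diff. For reference, SVRG (Johnson & Zhang, 2013) reduces gradient variance with the estimator `g_i(x) - g_i(x_snapshot) + full_grad(x_snapshot)`. A standalone sketch of that estimator, independent of the module's actual API (the `loss_i` signature below is hypothetical):

    import torch

    def svrg_minimize(x, loss_i, n_samples, lr=0.1, epochs=10, inner_steps=100):
        # loss_i(params, i) -> scalar loss of sample/minibatch i (hypothetical signature)
        x = x.detach().clone().requires_grad_(True)
        for _ in range(epochs):
            snapshot = x.detach().clone().requires_grad_(True)
            # full-batch gradient at the snapshot
            mu = torch.autograd.grad(
                sum(loss_i(snapshot, i) for i in range(n_samples)) / n_samples, snapshot
            )[0]
            for _ in range(inner_steps):
                i = int(torch.randint(n_samples, ()))
                g_x = torch.autograd.grad(loss_i(x, i), x)[0]
                g_s = torch.autograd.grad(loss_i(snapshot, i), snapshot)[0]
                with torch.no_grad():
                    x -= lr * (g_x - g_s + mu)  # variance-reduced step
        return x.detach()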