torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164) hide show
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -0,0 +1,170 @@
1
+ # pylint:disable=not-callable
2
+ from collections.abc import Callable
3
+
4
+ import torch
5
+
6
+ from ...core import Chainable, Module
7
+ from ...utils import TensorList, vec_to_tensors
8
+ from ...utils.linalg.linear_operator import LinearOperator
9
+ from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
10
+
11
+
12
# code from https://github.com/konstmish/opt_methods/blob/master/optmethods/second_order/cubic.py
# ported to pytorch and linear operator
def ls_cubic_solver(f, g:torch.Tensor, H:LinearOperator, M: float, loss_at_params_plus_x_fn: Callable | None, it_max=100, epsilon=1e-8, ):
    """
    Solve min_z <g, z-x> + 1/2<z-x, H(z-x)> + M/3 ||z-x||^3

    For explanation of Cauchy point, see "Gradient Descent
    Efficiently Finds the Cubic-Regularized Non-Convex Newton Step"
    https://arxiv.org/pdf/1612.00547.pdf
    Other potential implementations can be found in paper
    "Adaptive cubic regularisation methods"
    https://people.maths.ox.ac.uk/cartis/papers/ARCpI.pdf

    Args:
        f: loss value at the current parameters. Only used for the early-exit
            test when ``loss_at_params_plus_x_fn`` is given.
        g (torch.Tensor): gradient vector.
        H (LinearOperator): hessian (approximation) as a linear operator.
            Must support ``solve``, ``matvec`` and ``add_diagonal``.
        M (float): cubic penalty coefficient. ``M == 0`` reduces the problem
            to the plain Newton step, which is returned immediately.
        loss_at_params_plus_x_fn (Callable | None): optional callable returning
            the loss at ``params + x``. When given and the Newton step already
            decreases the loss (``f > loss(params + newton_step)``), the Newton
            step is returned without running the bisection.
        it_max (int): maximum number of bisection iterations. Defaults to 100.
        epsilon (float): tolerance on the convergence criterion and on the
            bisection bracket width. Defaults to 1e-8.

    Returns:
        tuple: ``(step, solver_it)`` where ``step`` is the solution of the cubic
        subproblem (a step to be added to the parameters) and ``solver_it`` is
        the number of linear-system solves performed.
    """
    solver_it = 1
    # Newton step -H^{-1} g; it is both the M == 0 solution and the upper
    # bound on the solution norm for the bisection below.
    newton_step = H.solve(g).neg_()
    if M == 0:
        return newton_step, solver_it

    def cauchy_point(g, H:LinearOperator, M):
        # Minimizer of the cubic model along the steepest-descent direction.
        # Its norm is a lower bound on the norm of the global solution.
        if torch.linalg.vector_norm(g) == 0 or M == 0:
            return 0 * g
        g_dir = g / torch.linalg.vector_norm(g)
        H_g_g = H.matvec(g_dir) @ g_dir  # curvature along the gradient direction
        R = -H_g_g / (2*M) + torch.sqrt((H_g_g/M)**2/4 + torch.linalg.vector_norm(g)/M)
        return -R * g_dir

    def conv_criterion(s, r):
        """
        The convergence criterion is an increasing and concave function in r
        and it is equal to 0 only if r is the solution to the cubic problem
        """
        s_norm = torch.linalg.vector_norm(s)
        return 1/s_norm - 1/r

    # Solution s satisfies ||s|| >= Cauchy_radius
    r_min = torch.linalg.vector_norm(cauchy_point(g, H, M))

    # Optional early exit: accept the plain Newton step if it already decreases f.
    if (loss_at_params_plus_x_fn is not None) and (f > loss_at_params_plus_x_fn(newton_step)):
        return newton_step, solver_it

    r_max = torch.linalg.vector_norm(newton_step)
    if r_max - r_min < epsilon:
        # Bracket already collapsed; the Newton step is (numerically) the solution.
        return newton_step, solver_it

    # id_matrix = torch.eye(g.size(0), device=g.device, dtype=g.dtype)
    s_lam = None
    # Bisect on the solution radius r: for a trial r, the regularized system
    # (H + r*M*I) s = -g gives the candidate step s(r).
    for _ in range(it_max):
        r_try = (r_min + r_max) / 2
        lam = r_try * M
        s_lam = H.add_diagonal(lam).solve(g).neg()
        # s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
        solver_it += 1
        crit = conv_criterion(s_lam, r_try)
        if torch.abs(crit) < epsilon:
            return s_lam, solver_it
        if crit < 0:
            # crit < 0  <=>  ||s(r_try)|| > r_try, so the true radius is larger.
            r_min = r_try
        else:
            r_max = r_try
        if r_max - r_min < epsilon:
            break
    # At least one bisection iteration ran (r_max - r_min >= epsilon above).
    assert s_lam is not None
    return s_lam, solver_it
75
+
76
+
77
class CubicRegularization(TrustRegionBase):
    """Cubic regularization.

    On each step this solves the cubic subproblem
    ``min_d <g, d> + 1/2 <d, Hd> + M/3 ||d||^3`` with ``M = 1/radius``
    via ``ls_cubic_solver``, adapting ``radius`` like a trust-region size.

    Args:
        hess_module (Chainable):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set ``inverse=False`` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is
            multiplied by ``nplus``. Defaults to 0.99.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is
            multiplied by ``nminus``. Defaults to 1e-4.
        init (float, optional): initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is
            returned when this limit is exceeded. Defaults to 10.
        radius_strategy (str | _RadiusStrategy, optional):
            strategy for adapting the trust region size. Defaults to ``'default'``.
        maxiter (int, optional): maximum iterations when solving the cubic subproblem. Defaults to 100.
        eps (float, optional): tolerance for the cubic subproblem solver. Defaults to 1e-8.
        check_decrease (bool, optional):
            if ``True``, the plain Newton step is accepted early whenever it already decreases
            the loss (costs one extra closure evaluation per step). Defaults to False.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.


    Examples:
        Cubic regularized newton

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.CubicRegularization(tz.m.Newton()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        maxiter: int = 100,
        eps: float = 1e-8,
        check_decrease: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(maxiter=maxiter, eps=eps, check_decrease=check_decrease)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            # cubic regularization has no hard boundary, so no boundary-based
            # radius logic is used
            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve the cubic subproblem with penalty ``M = 1/radius`` and return the update."""
        params = TensorList(params)

        loss_at_params_plus_x_fn = None
        if settings['check_decrease']:
            # Evaluates the loss at params + x without permanently moving params.
            def closure_plus_x(x):
                x_unflat = vec_to_tensors(x, params)
                params.add_(x_unflat)
                loss_x = closure(False)
                params.sub_(x_unflat)
                return loss_x
            loss_at_params_plus_x_fn = closure_plus_x

        # Larger trust radius => weaker cubic penalty (M = 1/radius).
        d, _ = ls_cubic_solver(f=f, g=g, H=H, M=1/radius, loss_at_params_plus_x_fn=loss_at_params_plus_x_fn,
                               it_max=settings['maxiter'], epsilon=settings['eps'])
        # ls_cubic_solver returns a step to add; negate because the trust-region
        # framework applies updates by subtraction — NOTE(review): confirm against TrustRegionBase.
        return d.neg_()
@@ -0,0 +1,92 @@
1
+ # pylint:disable=not-callable
2
+ import torch
3
+
4
+ from ...core import Chainable, Module
5
+ from ...utils import TensorList, vec_to_tensors
6
+ from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
7
+
8
class Dogleg(TrustRegionBase):
    """Dogleg trust region algorithm.

    The step interpolates between the Cauchy point (steepest-descent minimizer)
    and the Newton step, clipped to the trust-region boundary.

    Args:
        hess_module (Chainable):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set ``inverse=False`` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 2.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is
            multiplied by ``nplus``. Defaults to 0.75.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is
            multiplied by ``nminus``. Defaults to 0.25.
        boundary_tol (float | None, optional):
            the trust region only increases when the suggested step's norm is at least
            ``(1 - boundary_tol) * trust_region``. ``None`` disables the check. Defaults to None.
        init (float, optional): initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is
            returned when this limit is exceeded. Defaults to 10.
        radius_strategy (str | _RadiusStrategy, optional):
            strategy for adapting the trust region size. Defaults to ``'default'``.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 2,
        nminus: float = 0.25,
        rho_good: float = 0.75,
        rho_bad: float = 0.25,
        boundary_tol: float | None = None,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict()
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            # step norm is used as the effective radius for boundary checks
            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Compute the dogleg step for the given gradient, hessian operator and radius."""
        # NOTE(review): radius is hard-capped at 2 here (also written back to
        # global state) — beyond the unit Newton step a larger radius has no effect.
        if radius > 2: radius = self.global_state['radius'] = 2
        eps = torch.finfo(g.dtype).tiny * 2

        gHg = g.dot(H.matvec(g))
        if gHg <= eps:
            # Non-positive curvature along g: fall back to a boundary steepest-descent step.
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        p_cauchy = (g.dot(g) / gHg) * g
        p_newton = H.solve(g)

        # Solve ||p_cauchy + beta*(p_newton - p_cauchy)|| = radius for beta >= 0.
        a = p_newton - p_cauchy
        b = p_cauchy

        aa = a.dot(a)
        if aa < eps:
            # Newton and Cauchy points coincide; use the scaled gradient direction.
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        ab = a.dot(b)
        bb = b.dot(b)
        c = bb - radius**2
        discriminant = (2*ab)**2 - 4*aa*c
        # clip guards against tiny negative values from round-off
        beta = (-2*ab + torch.sqrt(discriminant.clip(min=0))) / (2 * aa)
        return p_cauchy + beta * (p_newton - p_cauchy)
92
+
@@ -0,0 +1,128 @@
1
+ # pylint:disable=not-callable
2
+ import torch
3
+
4
+ from ...core import Chainable, Module
5
+ from ...utils.linalg import linear_operator
6
+ from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
7
+
8
+
9
class LevenbergMarquardt(TrustRegionBase):
    """Levenberg-Marquardt trust region algorithm.

    Solves the damped system ``(H + (1/radius) * D) d = g`` where ``D`` is the
    identity (``y=0``), the hessian diagonal (``y=1``), or an interpolation.

    Args:
        hess_module (Chainable):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set ``inverse=False`` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is
            multiplied by ``nplus``. Defaults to 0.99.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is
            multiplied by ``nminus``. Defaults to 1e-4.
        init (float, optional): initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is
            returned when this limit is exceeded. Defaults to 10.
        radius_strategy (str | _RadiusStrategy, optional):
            strategy for adapting the trust region size. Defaults to ``'default'``.
        y (float, optional):
            when ``y=0``, identity matrix is added to hessian, when ``y=1``, diagonal of the hessian
            approximation is added. Values between interpolate. This should only be used with
            Gauss-Newton. Defaults to 0.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted
            efficiently, it will be inverted anyway. When ``False`` (default), a ``RuntimeError``
            will be raised instead.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Gauss-Newton with Levenberg-Marquardt trust-region

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
            )

        LM-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
            )

        First order trust region (hessian is assumed to be identity)

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.Identity()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        y: float = 0,
        fallback: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(y=y, fallback=fallback)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            # damping replaces an explicit boundary, so no boundary-based radius logic
            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve the damped linear system for the update direction."""
        y = settings['y']

        # LM needs the hessian itself; a stored inverse can only be used via
        # the (potentially inefficient and unstable) fallback re-inversion.
        if isinstance(H, linear_operator.DenseInverse):
            if settings['fallback']:
                H = H.to_dense()
            else:
                raise RuntimeError(
                    f"{self.children['hess_module']} maintains a hessian inverse. "
                    "LevenbergMarquardt requires the hessian, not the inverse. "
                    "If that module is a quasi-newton module, pass `inverse=False` on initialization. "
                    "Or pass `fallback=True` to LevenbergMarquardt to allow inverting the hessian inverse, "
                    "however that can be inefficient and unstable."
                )

        # Larger trust radius => weaker damping.
        reg = 1/radius
        if y == 0:
            return H.add_diagonal(reg).solve(g)

        diag = H.diagonal()
        # guard against non-positive / denormal diagonal entries
        diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
        if y != 1: diag = (diag*y) + (1-y)  # interpolate between identity and diagonal
        return H.add_diagonal(diag*reg).solve(g)
127
+
128
+
@@ -0,0 +1,99 @@
1
+ import torch
2
+
3
+ from ...core import Chainable, Module
4
+ from ...utils.linalg import cg, linear_operator
5
+ from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
6
+
7
+
8
class TrustCG(TrustRegionBase):
    """Trust region via Steihaug-Toint Conjugate Gradient method.

    .. note::

        If you wish to use exact hessian, use the matrix-free :code:`tz.m.NewtonCGSteihaug`
        which only uses hessian-vector products. While passing ``tz.m.Newton`` to this
        is possible, it is usually less efficient.

    Args:
        hess_module (Chainable):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set ``inverse=False`` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is
            multiplied by ``nplus``. Defaults to 0.99.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is
            multiplied by ``nminus``. Defaults to 1e-4.
        boundary_tol (float | None, optional):
            The trust region only increases when suggested step's norm is at least
            ``(1 - boundary_tol) * trust_region``. This prevents increasing trust region when
            solution is not on the boundary. Defaults to 1e-6.
        init (float, optional): initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is
            returned when this limit is exceeded. Defaults to 10.
        radius_strategy (str | _RadiusStrategy, optional):
            strategy for adapting the trust region size. Defaults to ``'default'``.
        reg (float, optional): regularization parameter for conjugate gradient. Defaults to 0.
        maxiter (int | None, optional): maximum CG iterations; ``None`` lets the solver decide. Defaults to None.
        miniter (int, optional): minimum CG iterations. Defaults to 1.
        cg_tol (float, optional): CG convergence tolerance. Defaults to 1e-8.
        prefer_exact (bool, optional):
            when exact solution can be easily calculated without CG (e.g. hessian is stored as
            scaled identity), uses the exact solution. If False, always uses CG. Defaults to True.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Trust-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
            )
    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        boundary_tol: float | None = 1e-6, # tuned
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        reg: float = 0,
        maxiter: int | None = None,
        miniter: int = 1,
        cg_tol: float = 1e-8,
        prefer_exact: bool = True,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol, maxiter=maxiter, miniter=miniter)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            # step norm is compared against the radius for the boundary check
            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve the trust-region subproblem with Steihaug-Toint CG (or exactly, when cheap)."""
        # For scaled-identity operators the bounded solution is available in closed form.
        if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
            return H.solve_bounded(g, radius)

        x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], maxiter=settings["maxiter"], miniter=settings["miniter"], tol=settings["cg_tol"])
        return x
+ return x