torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py
@@ -1,369 +1,355 @@
- from abc import ABC, abstractmethod
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
- from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
- from .quasi_newton import _safe_clip, HessianUpdateStrategy
-
-
- class ConguateGradientBase(Transform, ABC):
- """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
-
- This is an abstract class, to use it, subclass it and override `get_beta`.
-
-
- Args:
- defaults (dict | None, optional): dictionary of settings defaults. Defaults to None.
- clip_beta (bool, optional): whether to clip beta to be no less than 0. Defaults to False.
- reset_interval (int | None | Literal["auto"], optional):
- interval between resetting the search direction.
- "auto" means number of dimensions + 1, None means no reset. Defaults to None.
- inner (Chainable | None, optional): previous direction is added to the output of this module. Defaults to None.
-
- Example:
-
- .. code-block:: python
-
- class PolakRibiere(ConguateGradientBase):
- def __init__(
- self,
- clip_beta=True,
- reset_interval: int | None = None,
- inner: Chainable | None = None
- ):
- super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
- def get_beta(self, p, g, prev_g, prev_d):
- denom = prev_g.dot(prev_g)
- if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
- return g.dot(g - prev_g) / denom
-
- """
- def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
- if defaults is None: defaults = {}
- defaults['reset_interval'] = reset_interval
- defaults['clip_beta'] = clip_beta
- super().__init__(defaults, uses_grad=False)
-
- if inner is not None:
- self.set_child('inner', inner)
-
- def reset(self):
- super().reset()
-
- def reset_for_online(self):
- super().reset_for_online()
- self.clear_state_keys('prev_grad')
- self.global_state.pop('stage', None)
- self.global_state['step'] = self.global_state.get('step', 1) - 1
-
- def initialize(self, p: TensorList, g: TensorList):
- """runs on first step when prev_grads and prev_dir are not available"""
-
- @abstractmethod
- def get_beta(self, p: TensorList, g: TensorList, prev_g: TensorList, prev_d: TensorList) -> float | torch.Tensor:
- """returns beta"""
-
- @torch.no_grad
- def update_tensors(self, tensors, params, grads, loss, states, settings):
- tensors = as_tensorlist(tensors)
- params = as_tensorlist(params)
-
- step = self.global_state.get('step', 0) + 1
- self.global_state['step'] = step
-
- # initialize on first step
- if self.global_state.get('stage', 0) == 0:
- g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
- d_prev.copy_(tensors)
- g_prev.copy_(tensors)
- self.initialize(params, tensors)
- self.global_state['stage'] = 1
-
- else:
- # if `update_tensors` was called multiple times before `apply_tensors`,
- # stage becomes 2
- self.global_state['stage'] = 2
-
- @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
- tensors = as_tensorlist(tensors)
- step = self.global_state['step']
-
- if 'inner' in self.children:
- tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
-
- assert self.global_state['stage'] != 0
- if self.global_state['stage'] == 1:
- self.global_state['stage'] = 2
- return tensors
-
- params = as_tensorlist(params)
- g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
-
- # get beta
- beta = self.get_beta(params, tensors, g_prev, d_prev)
- if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
-
- # inner step
- # calculate new direction with beta
- dir = tensors.add_(d_prev.mul_(beta))
- d_prev.copy_(dir)
-
- # resetting
- reset_interval = settings[0]['reset_interval']
- if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
- if reset_interval is not None and step % reset_interval == 0:
- self.reset()
-
- return dir
-
- # ------------------------------- Polak-Ribière ------------------------------ #
- def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
- denom = prev_g.dot(prev_g)
- if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
- return g.dot(g - prev_g) / denom
-
- class PolakRibiere(ConguateGradientBase):
- """Polak-Ribière-Polyak nonlinear conjugate gradient method.
-
- .. note::
- - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
- """
- def __init__(self, clip_beta=True, reset_interval: int | None = None, inner: Chainable | None = None):
- super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
- def get_beta(self, p, g, prev_g, prev_d):
- return polak_ribiere_beta(g, prev_g)
-
- # ------------------------------ Fletcher–Reeves ----------------------------- #
- def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
- if prev_gg.abs() <= torch.finfo(gg.dtype).eps: return 0
- return gg / prev_gg
-
- class FletcherReeves(ConguateGradientBase):
- """Fletcher–Reeves nonlinear conjugate gradient method.
-
- .. note::
- - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
- """
- def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
- super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
- def initialize(self, p, g):
- self.global_state['prev_gg'] = g.dot(g)
-
- def get_beta(self, p, g, prev_g, prev_d):
- gg = g.dot(g)
- beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
- self.global_state['prev_gg'] = gg
- return beta
-
- # ----------------------------- Hestenes–Stiefel ----------------------------- #
- def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
- grad_diff = g - prev_g
- denom = prev_d.dot(grad_diff)
- if denom.abs() < torch.finfo(g[0].dtype).eps: return 0
- return (g.dot(grad_diff) / denom).neg()
-
-
- class HestenesStiefel(ConguateGradientBase):
- """Hestenes–Stiefel nonlinear conjugate gradient method.
-
- .. note::
- - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
- """
- def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
- super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
- def get_beta(self, p, g, prev_g, prev_d):
- return hestenes_stiefel_beta(g, prev_d, prev_g)
-
-
- # --------------------------------- Dai–Yuan --------------------------------- #
- def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
- denom = prev_d.dot(g - prev_g)
- if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
- return (g.dot(g) / denom).neg()
-
- class DaiYuan(ConguateGradientBase):
- """Dai–Yuan nonlinear conjugate gradient method.
-
- .. note::
- - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this. Although Dai–Yuan formula provides an automatic step size scaling so it is technically possible to omit line search and instead use a small step size.
- """
- def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
- super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
- def get_beta(self, p, g, prev_g, prev_d):
- return dai_yuan_beta(g, prev_d, prev_g)
-
-
- # -------------------------------- Liu-Storey -------------------------------- #
- def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
- denom = prev_g.dot(prev_d)
- if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
- return g.dot(g - prev_g) / denom
-
- class LiuStorey(ConguateGradientBase):
- """Liu-Storey nonlinear conjugate gradient method.
-
- .. note::
- - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
- """
- def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
- super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
- def get_beta(self, p, g, prev_g, prev_d):
- return liu_storey_beta(g, prev_d, prev_g)
-
- # ----------------------------- Conjugate Descent ---------------------------- #
- class ConjugateDescent(Transform):
- """Conjugate Descent (CD).
-
- .. note::
- - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
- """
- def __init__(self, inner: Chainable | None = None):
- super().__init__(defaults={}, uses_grad=False)
-
- if inner is not None:
- self.set_child('inner', inner)
-
-
- @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
- g = as_tensorlist(tensors)
-
- prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
- if 'denom' not in self.global_state:
- self.global_state['denom'] = torch.tensor(0.).to(g[0])
-
- prev_gd = self.global_state.get('prev_gd', 0)
- if abs(prev_gd) <= torch.finfo(g[0].dtype).eps: beta = 0
- else: beta = g.dot(g) / prev_gd
-
- # inner step
- if 'inner' in self.children:
- g = as_tensorlist(apply_transform(self.children['inner'], g, params, grads))
-
- dir = g.add_(prev_d.mul_(beta))
- prev_d.copy_(dir)
- self.global_state['prev_gd'] = g.dot(dir)
- return dir
-
-
- # -------------------------------- Hager-Zhang ------------------------------- #
- def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
- g_diff = g - prev_g
- denom = prev_d.dot(g_diff)
- if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
-
- term1 = 1/denom
- # term2
- term2 = (g_diff - (2 * prev_d * (g_diff.pow(2).global_sum()/denom))).dot(g)
- return (term1 * term2).neg()
-
-
- class HagerZhang(ConguateGradientBase):
- """Hager-Zhang nonlinear conjugate gradient method,
-
- .. note::
- - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
- """
- def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
- super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
- def get_beta(self, p, g, prev_g, prev_d):
- return hager_zhang_beta(g, prev_d, prev_g)
-
-
- # ----------------------------------- HS-DY ---------------------------------- #
- def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
- grad_diff = g - prev_g
- denom = prev_d.dot(grad_diff)
- if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
-
- # Dai-Yuan
- dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
-
- # Hestenes–Stiefel
- hs_beta = (g.dot(grad_diff) / denom).neg().clamp(min=0)
-
- return max(0, min(dy_beta, hs_beta)) # type:ignore
-
- class HybridHS_DY(ConguateGradientBase):
- """HS-DY hybrid conjugate gradient method.
-
- .. note::
- - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
- """
- def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
- super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
-
- def get_beta(self, p, g, prev_g, prev_d):
- return hs_dy_beta(g, prev_d, prev_g)
-
-
- def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
- Hy = H @ y
- yHy = _safe_clip(y.dot(Hy))
- H -= (Hy.outer(y) @ H) / yHy
- return H
-
- class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
- """Projected gradient method.
-
- .. note::
- This method uses N^2 memory.
-
- .. note::
- This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
-
- .. note::
- This is not the same as projected gradient descent.
-
- Reference:
- Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
-
- """
-
- def __init__(
- self,
- init_scale: float | Literal["auto"] = 1,
- tol: float = 1e-8,
- ptol: float | None = 1e-10,
- ptol_reset: bool = False,
- gtol: float | None = 1e-10,
- reset_interval: int | None | Literal['auto'] = 'auto',
- beta: float | None = None,
- update_freq: int = 1,
- scale_first: bool = False,
- scale_second: bool = False,
- concat_params: bool = True,
- # inverse: bool = True,
- inner: Chainable | None = None,
- ):
- super().__init__(
- defaults=None,
- init_scale=init_scale,
- tol=tol,
- ptol=ptol,
- ptol_reset=ptol_reset,
- gtol=gtol,
- reset_interval=reset_interval,
- beta=beta,
- update_freq=update_freq,
- scale_first=scale_first,
- scale_second=scale_second,
- concat_params=concat_params,
- inverse=True,
- inner=inner,
- )
-
-
-
- def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
- return projected_gradient_(H=H, y=y)
+ from abc import ABC, abstractmethod
+ from typing import Literal
+
+ import torch
+
+ from ...core import (
+ Chainable,
+ Modular,
+ Module,
+ Transform,
+ Var,
+ apply_transform,
+ )
+ from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
+ from ..line_search import LineSearchBase
+ from ..quasi_newton.quasi_newton import HessianUpdateStrategy
+ from ..functional import safe_clip
+
+
+ class ConguateGradientBase(Transform, ABC):
+ """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
+
+ This is an abstract class, to use it, subclass it and override `get_beta`.
+
+
+ Args:
+ defaults (dict | None, optional): dictionary of settings defaults. Defaults to None.
+ clip_beta (bool, optional): whether to clip beta to be no less than 0. Defaults to False.
+ restart_interval (int | None | Literal["auto"], optional):
+ interval between resetting the search direction.
+ "auto" means number of dimensions + 1, None means no reset. Defaults to None.
+ inner (Chainable | None, optional): previous direction is added to the output of this module. Defaults to None.
+
+ Example:
+
+ ```python
+ class PolakRibiere(ConguateGradientBase):
+ def __init__(
+ self,
+ clip_beta=True,
+ restart_interval: int | None = None,
+ inner: Chainable | None = None
+ ):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ denom = prev_g.dot(prev_g)
+ if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
+ return g.dot(g - prev_g) / denom
+ ```
+
+ """
+ def __init__(self, defaults = None, clip_beta: bool = False, restart_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
+ if defaults is None: defaults = {}
+ defaults['restart_interval'] = restart_interval
+ defaults['clip_beta'] = clip_beta
+ super().__init__(defaults, uses_grad=False)
+
+ if inner is not None:
+ self.set_child('inner', inner)
+
+
+ def reset_for_online(self):
+ super().reset_for_online()
+ self.clear_state_keys('prev_grad')
+ self.global_state.pop('stage', None)
+ self.global_state['step'] = self.global_state.get('step', 1) - 1
+
+ def initialize(self, p: TensorList, g: TensorList):
+ """runs on first step when prev_grads and prev_dir are not available"""
+
+ @abstractmethod
+ def get_beta(self, p: TensorList, g: TensorList, prev_g: TensorList, prev_d: TensorList) -> float | torch.Tensor:
+ """returns beta"""
+
+ @torch.no_grad
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
+ tensors = as_tensorlist(tensors)
+ params = as_tensorlist(params)
+
+ step = self.global_state.get('step', 0) + 1
+ self.global_state['step'] = step
+
+ # initialize on first step
+ if self.global_state.get('stage', 0) == 0:
+ g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+ d_prev.copy_(tensors)
+ g_prev.copy_(tensors)
+ self.initialize(params, tensors)
+ self.global_state['stage'] = 1
+
+ else:
+ # if `update_tensors` was called multiple times before `apply_tensors`,
+ # stage becomes 2
+ self.global_state['stage'] = 2
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ tensors = as_tensorlist(tensors)
+ step = self.global_state['step']
+
+ if 'inner' in self.children:
+ tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+
+ assert self.global_state['stage'] != 0
+ if self.global_state['stage'] == 1:
+ self.global_state['stage'] = 2
+ return tensors
+
+ params = as_tensorlist(params)
+ g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+
+ # get beta
+ beta = self.get_beta(params, tensors, g_prev, d_prev)
+ if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
+
+ # inner step
+ # calculate new direction with beta
+ dir = tensors.add_(d_prev.mul_(beta))
+ d_prev.copy_(dir)
+
+ # resetting
+ restart_interval = settings[0]['restart_interval']
+ if restart_interval == 'auto': restart_interval = tensors.global_numel() + 1
+ if restart_interval is not None and step % restart_interval == 0:
+ self.state.clear()
+ self.global_state.clear()
+
+ return dir
+
+ # ------------------------------- Polak-Ribière ------------------------------ #
+ def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
+ denom = prev_g.dot(prev_g)
+ if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+ return g.dot(g - prev_g) / denom
+
+ class PolakRibiere(ConguateGradientBase):
+ """Polak-Ribière-Polyak nonlinear conjugate gradient method.
+
+ Note:
+ This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+ """
+ def __init__(self, clip_beta=True, restart_interval: int | None = None, inner: Chainable | None = None):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ return polak_ribiere_beta(g, prev_g)
+
+ # ------------------------------ Fletcher–Reeves ----------------------------- #
+ def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
+ if prev_gg.abs() <= torch.finfo(gg.dtype).tiny * 2: return 0
+ return gg / prev_gg
+
+ class FletcherReeves(ConguateGradientBase):
+ """Fletcher–Reeves nonlinear conjugate gradient method.
+
+ Note:
+ This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+ """
+ def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def initialize(self, p, g):
+ self.global_state['prev_gg'] = g.dot(g)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ gg = g.dot(g)
+ beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
+ self.global_state['prev_gg'] = gg
+ return beta
+
+ # ----------------------------- Hestenes–Stiefel ----------------------------- #
+ def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+ grad_diff = g - prev_g
+ denom = prev_d.dot(grad_diff)
+ if denom.abs() < torch.finfo(g[0].dtype).tiny * 2: return 0
+ return (g.dot(grad_diff) / denom).neg()
+
+
+ class HestenesStiefel(ConguateGradientBase):
+ """Hestenes–Stiefel nonlinear conjugate gradient method.
+
+ Note:
+ This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+ """
+ def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ return hestenes_stiefel_beta(g, prev_d, prev_g)
+
+
+ # --------------------------------- Dai–Yuan --------------------------------- #
+ def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+ denom = prev_d.dot(g - prev_g)
+ if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+ return (g.dot(g) / denom).neg()
+
+ class DaiYuan(ConguateGradientBase):
+ """Dai–Yuan nonlinear conjugate gradient method.
+
+ Note:
+ This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1)`` after this.
+ """
+ def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ return dai_yuan_beta(g, prev_d, prev_g)
+
+
+ # -------------------------------- Liu-Storey -------------------------------- #
+ def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
+ denom = prev_g.dot(prev_d)
+ if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+ return g.dot(g - prev_g) / denom
+
+ class LiuStorey(ConguateGradientBase):
+ """Liu-Storey nonlinear conjugate gradient method.
+
+ Note:
+ This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+ """
+ def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ return liu_storey_beta(g, prev_d, prev_g)
+
+ # ----------------------------- Conjugate Descent ---------------------------- #
+ def conjugate_descent_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList):
+ denom = prev_g.dot(prev_d)
+ if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+ return g.dot(g) / denom
+
+ class ConjugateDescent(ConguateGradientBase):
+ """Conjugate Descent (CD).
+
+ Note:
+ This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+ """
+ def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ return conjugate_descent_beta(g, prev_d, prev_g)
+
+
+ # -------------------------------- Hager-Zhang ------------------------------- #
+ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
+ g_diff = g - prev_g
+ denom = prev_d.dot(g_diff)
+ if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+
+ term1 = 1/denom
+ # term2
+ term2 = (g_diff - (2 * prev_d * (g_diff.pow(2).global_sum()/denom))).dot(g)
+ return (term1 * term2).neg()
+
+
+ class HagerZhang(ConguateGradientBase):
+ """Hager-Zhang nonlinear conjugate gradient method,
+
+ Note:
+ This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+ """
+ def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ return hager_zhang_beta(g, prev_d, prev_g)
+
+
+ # ----------------------------------- DYHS ---------------------------------- #
+ def dyhs_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+ grad_diff = g - prev_g
+ denom = prev_d.dot(grad_diff)
+ if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+
+ # Dai-Yuan
+ dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
+
+ # Hestenes–Stiefel
+ hs_beta = (g.dot(grad_diff) / denom).neg().clamp(min=0)
+
+ return max(0, min(dy_beta, hs_beta)) # type:ignore
+
+ class DYHS(ConguateGradientBase):
+ """Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.
+
+ Note:
+ This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+ """
+ def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+ super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+ def get_beta(self, p, g, prev_g, prev_d):
+ return dyhs_beta(g, prev_d, prev_g)
+
+
+ def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
+ Hy = H @ y
+ yHy = safe_clip(y.dot(Hy))
+ H -= (Hy.outer(y) @ H) / yHy
+ return H
+
+ class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
+ """Projected gradient method. Directly projects the gradient onto subspace conjugate to past directions.
+
+ Notes:
+ - This method uses N^2 memory.
+ - This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+ - This is not the same as projected gradient descent.
+
+ Reference:
+ Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171. (algorithm 5 in section 6)
+
+ """
+
+ def __init__(
+ self,
+ init_scale: float | Literal["auto"] = 1,
+ tol: float = 1e-32,
+ ptol: float | None = 1e-32,
+ ptol_restart: bool = False,
+ gtol: float | None = 1e-32,
+ restart_interval: int | None | Literal['auto'] = 'auto',
+ beta: float | None = None,
+ update_freq: int = 1,
+ scale_first: bool = False,
+ concat_params: bool = True,
+ # inverse: bool = True,
+ inner: Chainable | None = None,
+ ):
+ super().__init__(
+ defaults=None,
+ init_scale=init_scale,
+ tol=tol,
+ ptol=ptol,
+ ptol_restart=ptol_restart,
+ gtol=gtol,
+ restart_interval=restart_interval,
+ beta=beta,
+ update_freq=update_freq,
+ scale_first=scale_first,
+ concat_params=concat_params,
+ inverse=True,
+ inner=inner,
+ )
+
+
+
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+ return projected_gradient_(H=H, y=y)
+
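
The docstrings in the new file all give the same usage guidance: these modules produce a conjugate-gradient search direction and expect a line search such as ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` to be chained after them. Below is a minimal sketch of such a chain. It assumes the ``tz.Modular(parameters, *modules)`` entry point (``Modular`` is imported in the new module but its call signature is not shown in this diff) and a closure that accepts a ``backward`` flag, as torchzero line searches are documented to use; exact signatures in 0.3.13 may differ, so treat this as illustrative rather than authoritative.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# Hypothetical chain: a Polak-Ribiere CG direction followed by a Strong Wolfe
# line search, as recommended by the module docstrings above.
opt = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),
    tz.m.StrongWolfe(c2=0.1, a_init="first-order"),
)

def closure(backward=True):
    # the line search re-evaluates the loss at trial points, so the loss is
    # recomputed on every call; gradients are only needed when backward=True
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(25):
    opt.step(closure)
```

Any other direction module from this file (``FletcherReeves``, ``DaiYuan``, ``DYHS``, ...) could be swapped in for ``PolakRibiere`` in the same position.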