torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/smoothing/sampling.py (new file)
@@ -0,0 +1,300 @@
1
+ import math
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import Callable, Sequence
4
+ from contextlib import nullcontext
5
+ from functools import partial
6
+ from typing import Literal, cast
7
+
8
+ import torch
9
+
10
+ from ...core import Chainable, Modular, Module, Var
11
+ from ...core.reformulation import Reformulation
12
+ from ...utils import Distributions, NumberList, TensorList
13
+ from ..termination import TerminationCriteriaBase, make_termination_criteria
14
+
15
+
16
+ def _reset_except_self(optimizer: Modular, var: Var, self: Module):
17
+ for m in optimizer.unrolled_modules:
18
+ if m is not self:
19
+ m.reset()
20
+
21
+
22
+ class GradientSampling(Reformulation):
23
+ """Samples and aggregates gradients and values at perturbed points.
24
+
25
+ This module can be used for gaussian homotopy and gradient sampling methods.
26
+
27
+ Args:
28
+ modules (Chainable | None, optional):
29
+ modules that will be optimizing the modified objective.
30
+ If None, returns the gradient of the modified objective as the update. Defaults to None.
31
+ sigma (float, optional): initial magnitude of the perturbations. Defaults to 1.
32
+ n (int, optional): number of perturbations per step. Defaults to 100.
33
+ aggregate (str, optional):
34
+ how to aggregate values and gradients
35
+ - "mean" - uses mean of the gradients, as in gaussian homotopy.
36
+ - "max" - uses element-wise maximum of the gradients.
37
+ - "min" - uses element-wise minimum of the gradients.
38
+ - "min-norm" - picks gradient with the lowest norm.
39
+
40
+ Defaults to 'mean'.
41
+ distribution (Distributions, optional): distribution for random perturbations. Defaults to 'gaussian'.
42
+ include_x0 (bool, optional): whether to include gradient at un-perturbed point. Defaults to True.
43
+ fixed (bool, optional):
44
+ if True, perturbations do not get replaced by new random perturbations until termination criteria is satisfied. Defaults to True.
45
+ pre_generate (bool, optional):
46
+ if True, perturbations are pre-generated before each step.
47
+ This requires more memory to store all of them,
48
+ but ensures they do not change when closure is evaluated multiple times.
49
+ Defaults to True.
50
+ termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, optional):
51
+ a termination criteria module; sigma is multiplied by ``decay`` when the criteria is satisfied,
52
+ and new perturbations will be generated if ``fixed``. Defaults to None.
53
+ decay (float, optional): sigma multiplier on termination criteria. Defaults to 2/3.
54
+ reset_on_termination (bool, optional): whether to reset states of all other modules on termination. Defaults to True.
55
+ sigma_strategy (str | None, optional):
56
+ strategy for adapting sigma. If condition is satisfied, sigma is multiplied by ``sigma_nplus``,
57
+ otherwise it is multiplied by ``sigma_nminus``.
58
+ - "grad-norm" - at least ``sigma_target`` gradients should have lower norm than at un-perturbed point.
59
+ - "value" - at least ``sigma_target`` values (losses) should be lower than at un-perturbed point.
60
+ - None - doesn't use adaptive sigma.
61
+
62
+ This introduces a side-effect to the closure, so it should be left at None if you use
63
+ trust region or line search to optimize the modified objective.
64
+ Defaults to None.
65
+ sigma_target (int | float, optional):
66
+ number of perturbations that should satisfy the ``sigma_strategy`` condition; a float is interpreted as a fraction of the finite evaluations. Defaults to 0.2.
67
+ sigma_nplus (float, optional): sigma multiplier when ``sigma_strategy`` condition is satisfied. Defaults to 4/3.
68
+ sigma_nminus (float, optional): sigma multiplier when ``sigma_strategy`` condition is not satisfied. Defaults to 2/3.
69
+ seed (int | None, optional): seed. Defaults to None.
70
+ """
71
+ def __init__(
72
+ self,
73
+ modules: Chainable | None = None,
74
+ sigma: float = 1.,
75
+ n:int = 100,
76
+ aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = 'mean',
77
+ distribution: Distributions = 'gaussian',
78
+ include_x0: bool = True,
79
+
80
+ fixed: bool=True,
81
+ pre_generate: bool = True,
82
+ termination: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
83
+ decay: float = 2/3,
84
+ reset_on_termination: bool = True,
85
+
86
+ sigma_strategy: Literal['grad-norm', 'value'] | None = None,
87
+ sigma_target: int | float = 0.2,
88
+ sigma_nplus: float = 4/3,
89
+ sigma_nminus: float = 2/3,
90
+
91
+ seed: int | None = None,
92
+ ):
93
+
94
+ defaults = dict(sigma=sigma, n=n, aggregate=aggregate, distribution=distribution, seed=seed, include_x0=include_x0, fixed=fixed, decay=decay, reset_on_termination=reset_on_termination, sigma_strategy=sigma_strategy, sigma_target=sigma_target, sigma_nplus=sigma_nplus, sigma_nminus=sigma_nminus, pre_generate=pre_generate)
95
+ super().__init__(defaults, modules)
96
+
97
+ if termination is not None:
98
+ self.set_child('termination', make_termination_criteria(extra=termination))
99
+
100
+ @torch.no_grad
101
+ def pre_step(self, var):
102
+ params = TensorList(var.params)
103
+
104
+ fixed = self.defaults['fixed']
105
+
106
+ # check termination criteria
107
+ if 'termination' in self.children:
108
+ termination = cast(TerminationCriteriaBase, self.children['termination'])
109
+ if termination.should_terminate(var):
110
+
111
+ # decay sigmas
112
+ states = [self.state[p] for p in params]
113
+ settings = [self.settings[p] for p in params]
114
+
115
+ for state, setting in zip(states, settings):
116
+ if 'sigma' not in state: state['sigma'] = setting['sigma']
117
+ state['sigma'] *= setting['decay']
118
+
119
+ # reset on sigmas decay
120
+ if self.defaults['reset_on_termination']:
121
+ var.post_step_hooks.append(partial(_reset_except_self, self=self))
122
+
123
+ # clear perturbations
124
+ self.global_state.pop('perts', None)
125
+
126
+ # pre-generate perturbations if not already pre-generated or not fixed
127
+ if self.defaults['pre_generate'] and (('perts' not in self.global_state) or (not fixed)):
128
+ states = [self.state[p] for p in params]
129
+ settings = [self.settings[p] for p in params]
130
+
131
+ n = self.defaults['n'] - self.defaults['include_x0']
132
+ generator = self.get_generator(params[0].device, self.defaults['seed'])
133
+
134
+ perts = [params.sample_like(self.defaults['distribution'], generator=generator) for _ in range(n)]
135
+
136
+ self.global_state['perts'] = perts
137
+
138
+ @torch.no_grad
139
+ def closure(self, backward, closure, params, var):
140
+ params = TensorList(params)
141
+ loss_agg = None
142
+ grad_agg = None
143
+
144
+ states = [self.state[p] for p in params]
145
+ settings = [self.settings[p] for p in params]
146
+ sigma_inits = [s['sigma'] for s in settings]
147
+ sigmas = [s.setdefault('sigma', si) for s, si in zip(states, sigma_inits)]
148
+
149
+ include_x0 = self.defaults['include_x0']
150
+ pre_generate = self.defaults['pre_generate']
151
+ aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = self.defaults['aggregate']
152
+ sigma_strategy: Literal['grad-norm', 'value'] | None = self.defaults['sigma_strategy']
153
+ distribution = self.defaults['distribution']
154
+ generator = self.get_generator(params[0].device, self.defaults['seed'])
155
+
156
+
157
+ n_finite = 0
158
+ n_good = 0
159
+ f_0 = None; g_0 = None
160
+
161
+ # evaluate at x_0
162
+ if include_x0:
163
+ f_0 = cast(torch.Tensor, var.get_loss(backward=backward))
164
+
165
+ isfinite = math.isfinite(f_0)
166
+ if isfinite:
167
+ n_finite += 1
168
+ loss_agg = f_0
169
+
170
+ if backward:
171
+ g_0 = var.get_grad()
172
+ if isfinite: grad_agg = g_0
173
+
174
+ # evaluate at x_0 + p for each perturbation
175
+ if pre_generate:
176
+ perts = self.global_state['perts']
177
+ else:
178
+ perts = [None] * (self.defaults['n'] - include_x0)
179
+
180
+ x_0 = [p.clone() for p in params]
181
+
182
+ for pert in perts:
183
+ loss = None; grad = None
184
+
185
+ # generate if not pre-generated
186
+ if pert is None:
187
+ pert = params.sample_like(distribution, generator=generator)
188
+
189
+ # add perturbation and evaluate
190
+ pert = pert * sigmas
191
+ torch._foreach_add_(params, pert)
192
+
193
+ with torch.enable_grad() if backward else nullcontext():
194
+ loss = closure(backward)
195
+
196
+ if math.isfinite(loss):
197
+ n_finite += 1
198
+
199
+ # add loss
200
+ if loss_agg is None:
201
+ loss_agg = loss
202
+ else:
203
+ if aggregate == 'mean':
204
+ loss_agg += loss
205
+
206
+ elif (aggregate=='min') or (aggregate=='min-value') or (aggregate=='min-norm' and not backward):
207
+ loss_agg = loss_agg.clamp(max=loss)
208
+
209
+ elif aggregate == 'max':
210
+ loss_agg = loss_agg.clamp(min=loss)
211
+
212
+ # add grad
213
+ if backward:
214
+ grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
215
+ if grad_agg is None:
216
+ grad_agg = grad
217
+ else:
218
+ if aggregate == 'mean':
219
+ torch._foreach_add_(grad_agg, grad)
220
+
221
+ elif aggregate == 'min':
222
+ grad_agg_abs = torch._foreach_abs(grad_agg)
223
+ torch._foreach_minimum_(grad_agg_abs, torch._foreach_abs(grad))
224
+ grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]
225
+
226
+ elif aggregate == 'max':
227
+ grad_agg_abs = torch._foreach_abs(grad_agg)
228
+ torch._foreach_maximum_(grad_agg_abs, torch._foreach_abs(grad))
229
+ grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]
230
+
231
+ elif aggregate == 'min-norm':
232
+ if TensorList(grad).global_vector_norm() < TensorList(grad_agg).global_vector_norm():
233
+ grad_agg = grad
234
+ loss_agg = loss
235
+
236
+ elif aggregate == 'min-value':
237
+ if loss < loss_agg:
238
+ grad_agg = grad
239
+ loss_agg = loss
240
+
241
+ # undo perturbation
242
+ torch._foreach_copy_(params, x_0)
243
+
244
+ # adaptive sigma
245
+ # by value
246
+ if sigma_strategy == 'value':
247
+ if f_0 is None:
248
+ with torch.enable_grad() if backward else nullcontext():
249
+ f_0 = closure(False)
250
+
251
+ if loss < f_0:
252
+ n_good += 1
253
+
254
+ # by gradient norm
255
+ elif sigma_strategy == 'grad-norm' and backward and math.isfinite(loss):
256
+ assert grad is not None
257
+ if g_0 is None:
258
+ with torch.enable_grad() if backward else nullcontext():
259
+ closure()
260
+ g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
261
+
262
+ if TensorList(grad).global_vector_norm() < TensorList(g_0).global_vector_norm():
263
+ n_good += 1
264
+
265
+ # update sigma if strategy is enabled
266
+ if sigma_strategy is not None:
267
+
268
+ sigma_target = self.defaults['sigma_target']
269
+ if isinstance(sigma_target, float):
270
+ sigma_target = int(max(1, n_finite * sigma_target))
271
+
272
+ if n_good >= sigma_target:
273
+ key = 'sigma_nplus'
274
+ else:
275
+ key = 'sigma_nminus'
276
+
277
+ for p in params:
278
+ self.state[p]['sigma'] *= self.settings[p][key]
279
+
280
+ # if no finite losses, just return inf
281
+ if n_finite == 0:
282
+ assert loss_agg is None and grad_agg is None
283
+ loss = torch.tensor(torch.inf, dtype=params[0].dtype, device=params[0].device)
284
+ grad = [torch.full_like(p, torch.inf) for p in params]
285
+ return loss, grad
286
+
287
+ assert loss_agg is not None
288
+
289
+ # no post processing needed when aggregate is 'max', 'min', 'min-norm', 'min-value'
290
+ if aggregate != 'mean':
291
+ return loss_agg, grad_agg
292
+
293
+ # on mean divide by number of evals
294
+ loss_agg /= n_finite
295
+
296
+ if backward:
297
+ assert grad_agg is not None
298
+ torch._foreach_div_(grad_agg, n_finite)
299
+
300
+ return loss_agg, grad_agg
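The 'mean' aggregation above amounts to a Monte-Carlo estimate of a Gaussian-smoothed objective (gaussian homotopy). Below is a minimal standalone sketch of that idea in plain PyTorch, independent of the torchzero module API; the helper name smoothed_value_and_grad and the toy objective are hypothetical.

import torch

def smoothed_value_and_grad(objective, x: torch.Tensor, sigma: float = 1.0, n: int = 64):
    # Monte-Carlo estimate of E_u[f(x + sigma*u)], u ~ N(0, I), and of its gradient:
    # average losses and gradients over n points (x itself included, as with include_x0=True).
    total_loss = 0.0
    total_grad = torch.zeros_like(x)
    points = [x] + [x + sigma * torch.randn_like(x) for _ in range(n - 1)]
    for point in points:
        point = point.detach().requires_grad_(True)
        loss = objective(point)
        grad, = torch.autograd.grad(loss, point)
        total_loss += float(loss)
        total_grad += grad
    return total_loss / n, total_grad / n

# usage: the gradient of the smoothed objective still points toward the minimum
f = lambda x: (x ** 2).sum()
loss, grad = smoothed_value_and_grad(f, torch.tensor([3.0, -2.0]), sigma=0.5)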
torchzero/modules/step_size/__init__.py (new file)
@@ -0,0 +1,2 @@
1
+ from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
2
+ from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD
torchzero/modules/step_size/adaptive.py (new file)
@@ -0,0 +1,387 @@
1
+ """Various step size strategies"""
2
+ import math
3
+ from operator import itemgetter
4
+ from typing import Any, Literal
5
+
6
+ import torch
7
+
8
+ from ...core import Chainable, Transform
9
+ from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
10
+ from ...utils.linalg.linear_operator import ScaledIdentity
11
+ from ..functional import epsilon_step_size
12
+
13
+ def _acceptable_alpha(alpha, param:torch.Tensor):
14
+ finfo = torch.finfo(param.dtype)
15
+ if (alpha is None) or (alpha < finfo.tiny*2) or (not math.isfinite(alpha)) or (alpha > finfo.max/2):
16
+ return False
17
+ return True
18
+
19
+ def _get_H(self: Transform, var):
20
+ n = sum(p.numel() for p in var.params)
21
+ p = var.params[0]
22
+ alpha = self.global_state.get('alpha', 1)
23
+ if not _acceptable_alpha(alpha, p): alpha = 1
24
+
25
+ return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)
26
+
27
+
28
+ class PolyakStepSize(Transform):
29
+ """Polyak's subgradient method with known or unknown f*.
30
+
31
+ Args:
32
+ f_star (float | None, optional):
33
+ minimal possible value of the objective function. If not known, set to ``None``. Defaults to 0.
34
+ y (float, optional):
35
+ when ``f_star`` is set to None, it is calculated as ``f_best - y``.
36
+ y_decay (float, optional):
37
+ ``y`` is multiplied by ``(1 - y_decay)`` after each step. Defaults to 1e-3.
38
+ max (float | None, optional): maximum possible step size. Defaults to None.
39
+ use_grad (bool, optional):
40
+ if True, uses dot product of update and gradient to compute the step size.
41
+ Otherwise, dot product of update with itself is used.
42
+ alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
43
+ """
44
+ def __init__(self, f_star: float | None = 0, y: float = 1, y_decay: float = 1e-3, max: float | None = None, use_grad=True, alpha: float = 1, inner: Chainable | None = None):
45
+
46
+ defaults = dict(alpha=alpha, max=max, f_star=f_star, y=y, y_decay=y_decay)
47
+ super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
48
+
49
+ @torch.no_grad
50
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
51
+ assert grads is not None and loss is not None
52
+ tensors = TensorList(tensors)
53
+ grads = TensorList(grads)
54
+
55
+ # load variables
56
+ max, f_star, y, y_decay = itemgetter('max', 'f_star', 'y', 'y_decay')(settings[0])
57
+ y_val = self.global_state.get('y_val', y)
58
+ f_best = self.global_state.get('f_best', None)
59
+
60
+ # gg
61
+ if self._uses_grad: gg = tensors.dot(grads)
62
+ else: gg = tensors.dot(tensors)
63
+
64
+ # store loss
65
+ if f_best is None or loss < f_best: f_best = tofloat(loss)
66
+ if f_star is None: f_star = f_best - y_val
67
+
68
+ # calculate the step size
69
+ if gg <= torch.finfo(gg.dtype).tiny * 2: alpha = 0 # converged
70
+ else: alpha = (loss - f_star) / gg
71
+
72
+ # clip
73
+ if max is not None:
74
+ if alpha > max: alpha = max
75
+
76
+ # store state
77
+ self.global_state['f_best'] = f_best
78
+ self.global_state['y_val'] = y_val * (1 - y_decay)
79
+ self.global_state['alpha'] = alpha
80
+
81
+ @torch.no_grad
82
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
83
+ alpha = self.global_state.get('alpha', 1)
84
+ if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))
85
+
86
+ torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
87
+ return tensors
88
+
89
+ def get_H(self, var):
90
+ return _get_H(self, var)
91
+
92
+
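PolyakStepSize above implements the classical Polyak rule alpha = (f(x) - f*) / ⟨d, g⟩, with d the update and g the gradient (or ⟨d, d⟩ when use_grad=False), optionally capped by ``max``. A minimal sketch under those assumptions; the helper name polyak_step_size is hypothetical.

import torch

def polyak_step_size(loss: float, grad: torch.Tensor, f_star: float = 0.0,
                     max_step: float | None = None) -> float:
    gg = float(grad.dot(grad))                # <g, g>; the module uses <update, grad> when use_grad=True
    if gg <= torch.finfo(grad.dtype).tiny * 2:
        return 0.0                            # gradient vanished, treat as converged
    alpha = (loss - f_star) / gg
    if max_step is not None:
        alpha = min(alpha, max_step)
    return alpha

# e.g. for f(x) = ||x|| (with f* = 0) the Polyak step lands exactly at the minimum:
x = torch.tensor([3.0, 4.0])
g = x / x.norm()                              # (sub)gradient of ||x||
alpha = polyak_step_size(float(x.norm()), g)  # == 5.0
x_next = x - alpha * g                        # == [0., 0.]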
93
+ def _bb_short(s: TensorList, y: TensorList, sy, eps):
94
+ yy = y.dot(y)
95
+ if yy < eps:
96
+ if sy < eps: return None # try to fallback on long
97
+ ss = s.dot(s)
98
+ return ss/sy
99
+ return sy/yy
100
+
101
+ def _bb_long(s: TensorList, y: TensorList, sy, eps):
102
+ ss = s.dot(s)
103
+ if sy < eps:
104
+ yy = y.dot(y) # try to fallback on short
105
+ if yy < eps: return None
106
+ return sy/yy
107
+ return ss/sy
108
+
109
+ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
110
+ short = _bb_short(s, y, sy, eps)
111
+ long = _bb_long(s, y, sy, eps)
112
+ if long is None or short is None:
113
+ if fallback:
114
+ if short is not None: return short
115
+ if long is not None: return long
116
+ return None
117
+ return (short * long) ** 0.5
118
+
119
+ class BarzilaiBorwein(Transform):
120
+ """Barzilai-Borwein step size method.
121
+
122
+ Args:
123
+ type (str, optional):
124
+ one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
125
+ Defaults to "geom".
126
+ alpha_0 (float, optional): fallback step size used when the denominator is not positive (which happens on negative curvature). Defaults to 1e-7.
127
+ inner (Chainable | None, optional):
128
+ step size will be applied to outputs of this module. Defaults to None.
129
+ """
130
+
131
+ def __init__(
132
+ self,
133
+ type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
134
+ alpha_0: float = 1e-7,
135
+ use_grad=True,
136
+ inner: Chainable | None = None,
137
+ ):
138
+ defaults = dict(type=type, alpha_0=alpha_0)
139
+ super().__init__(defaults, uses_grad=use_grad, inner=inner)
140
+
141
+ def reset_for_online(self):
142
+ super().reset_for_online()
143
+ self.clear_state_keys('prev_g')
144
+ self.global_state['reset'] = True
145
+
146
+ @torch.no_grad
147
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
148
+ step = self.global_state.get('step', 0)
149
+ self.global_state['step'] = step + 1
150
+
151
+ prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
152
+ type = self.defaults['type']
153
+
154
+ g = grads if self._uses_grad else tensors
155
+ assert g is not None
156
+
157
+ reset = self.global_state.get('reset', False)
158
+ self.global_state.pop('reset', None)
159
+
160
+ if step != 0 and not reset:
161
+ s = params-prev_p
162
+ y = g-prev_g
163
+ sy = s.dot(y)
164
+ eps = torch.finfo(sy.dtype).tiny * 2
165
+
166
+ if type == 'short': alpha = _bb_short(s, y, sy, eps)
167
+ elif type == 'long': alpha = _bb_long(s, y, sy, eps)
168
+ elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
169
+ elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
170
+ else: raise ValueError(type)
171
+
172
+ # if alpha is not None:
173
+ self.global_state['alpha'] = alpha
174
+
175
+ prev_p.copy_(params)
176
+ prev_g.copy_(g)
177
+
178
+ def get_H(self, var):
179
+ return _get_H(self, var)
180
+
181
+ @torch.no_grad
182
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
183
+ alpha = self.global_state.get('alpha', None)
184
+
185
+ if not _acceptable_alpha(alpha, tensors[0]):
186
+ alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
187
+
188
+ torch._foreach_mul_(tensors, alpha)
189
+ return tensors
190
+
191
+
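The helpers above compute the two Barzilai-Borwein step sizes from s = x_k − x_{k−1} and y = g_k − g_{k−1}: the "long" step sᵀs/sᵀy, the "short" step sᵀy/yᵀy, and their geometric mean for "geom". A minimal standalone sketch follows; the helper name bb_step_size is hypothetical, and unlike _bb_short/_bb_long it simply returns None on non-positive curvature instead of falling back to the other formula.

import math
import torch

def bb_step_size(s: torch.Tensor, y: torch.Tensor, kind: str = "geom") -> float | None:
    sy = float(s.dot(y))
    if sy <= 0:
        return None                         # negative curvature: caller falls back to alpha_0
    long = float(s.dot(s)) / sy             # "long" (BB1)
    short = sy / float(y.dot(y))            # "short" (BB2)
    if kind == "long":
        return long
    if kind == "short":
        return short
    return math.sqrt(long * short)          # "geom"

# on a quadratic f(x) = 0.5 * x^T A x both steps lie between the inverse eigenvalues of A:
A = torch.diag(torch.tensor([1.0, 10.0]))
x0, x1 = torch.tensor([1.0, 1.0]), torch.tensor([0.9, 0.0])
alpha = bb_step_size(x1 - x0, A @ x1 - A @ x0)   # ~0.1, i.e. close to 1/10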
192
+ class BBStab(Transform):
193
+ """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).
194
+
195
+ This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.
196
+
197
+ Args:
198
+ c (float, optional):
199
+ adaptive delta parameter. If ``delta`` is set to None, the first ``inf_iters`` updates are performed
200
+ with the non-stabilized Barzilai-Borwein step size. Then ``delta`` is set to the smallest
201
+ update norm observed during those iterations, multiplied by ``c``. Defaults to 0.2.
202
+ delta (float | None, optional):
203
+ Barzilai-Borwein update is clipped to this value. Set to ``None`` to use an adaptive choice. Defaults to None.
204
+ type (str, optional):
205
+ one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
206
+ Defaults to "geom". Note that "long" corresponds to BB1stab and "short" to BB2stab,
207
+ however I found that "geom" works really well.
208
+ inner (Chainable | None, optional):
209
+ step size will be applied to outputs of this module. Defaults to None.
210
+
211
+ """
212
+ def __init__(
213
+ self,
214
+ c=0.2,
215
+ delta:float | None = None,
216
+ type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
217
+ alpha_0: float = 1e-7,
218
+ use_grad=True,
219
+ inf_iters: int = 3,
220
+ inner: Chainable | None = None,
221
+ ):
222
+ defaults = dict(type=type,alpha_0=alpha_0, c=c, delta=delta, inf_iters=inf_iters)
223
+ super().__init__(defaults, uses_grad=use_grad, inner=inner)
224
+
225
+ def reset_for_online(self):
226
+ super().reset_for_online()
227
+ self.clear_state_keys('prev_g')
228
+ self.global_state['reset'] = True
229
+
230
+ @torch.no_grad
231
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
232
+ step = self.global_state.get('step', 0)
233
+ self.global_state['step'] = step + 1
234
+
235
+ prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
236
+ type = self.defaults['type']
237
+ c = self.defaults['c']
238
+ delta = self.defaults['delta']
239
+ inf_iters = self.defaults['inf_iters']
240
+
241
+ g = grads if self._uses_grad else tensors
242
+ assert g is not None
243
+ g = TensorList(g)
244
+
245
+ reset = self.global_state.get('reset', False)
246
+ self.global_state.pop('reset', None)
247
+
248
+ if step != 0 and not reset:
249
+ s = params-prev_p
250
+ y = g-prev_g
251
+ sy = s.dot(y)
252
+ eps = torch.finfo(sy.dtype).tiny
253
+
254
+ if type == 'short': alpha = _bb_short(s, y, sy, eps)
255
+ elif type == 'long': alpha = _bb_long(s, y, sy, eps)
256
+ elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
257
+ elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
258
+ else: raise ValueError(type)
259
+
260
+ if alpha is not None:
261
+
262
+ # adaptive delta
263
+ if delta is None:
264
+ niters = self.global_state.get('niters', 0) # this accounts for skipped negative curvature steps
265
+ self.global_state['niters'] = niters + 1
266
+
267
+
268
+ if niters == 0: pass # 1st iteration is scaled GD step, shouldn't be used to find s_norm_min
269
+ elif niters <= inf_iters:
270
+ s_norm_min = self.global_state.get('s_norm_min', None)
271
+ if s_norm_min is None: s_norm_min = s.global_vector_norm()
272
+ else: s_norm_min = min(s_norm_min, s.global_vector_norm())
273
+ self.global_state['s_norm_min'] = s_norm_min
274
+ # first few steps use delta=inf, so delta remains None
275
+
276
+ else:
277
+ delta = c * self.global_state['s_norm_min']
278
+
279
+ if delta is None: # delta is inf for first few steps
280
+ self.global_state['alpha'] = alpha
281
+
282
+ # BBStab step size
283
+ else:
284
+ a_stab = delta / g.global_vector_norm()
285
+ self.global_state['alpha'] = min(alpha, a_stab)
286
+
287
+ prev_p.copy_(params)
288
+ prev_g.copy_(g)
289
+
290
+ def get_H(self, var):
291
+ return _get_H(self, var)
292
+
293
+ @torch.no_grad
294
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
295
+ alpha = self.global_state.get('alpha', None)
296
+
297
+ if not _acceptable_alpha(alpha, tensors[0]):
298
+ alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
299
+
300
+ torch._foreach_mul_(tensors, alpha)
301
+ return tensors
302
+
303
+
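BBStab above stabilizes the Barzilai-Borwein step by capping the norm of the update: the applied step size is min(alpha_BB, delta / ||g||), with delta either given or chosen adaptively as c times the smallest update norm seen after the first inf_iters iterations. A one-line sketch of that cap (hypothetical helper name):

def stabilized_step_size(alpha_bb: float, grad_norm: float, delta: float) -> float:
    # cap the BB step so that ||alpha * g|| <= delta
    return min(alpha_bb, delta / grad_norm)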
304
+ class AdGD(Transform):
305
+ """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
306
+ def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
307
+ defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
308
+ super().__init__(defaults, uses_grad=use_grad, inner=inner,)
309
+
310
+ def reset_for_online(self):
311
+ super().reset_for_online()
312
+ self.clear_state_keys('prev_g')
313
+ self.global_state['reset'] = True
314
+
315
+ @torch.no_grad
316
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
317
+ variant = settings[0]['variant']
318
+ theta_0 = 0 if variant == 1 else 1/3
319
+ theta = self.global_state.get('theta', theta_0)
320
+
321
+ step = self.global_state.get('step', 0)
322
+ self.global_state['step'] = step + 1
323
+
324
+ p = TensorList(params)
325
+ g = grads if self._uses_grad else tensors
326
+ assert g is not None
327
+ g = TensorList(g)
328
+
329
+ prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
330
+
331
+ # online
332
+ if self.global_state.get('reset', False):
333
+ del self.global_state['reset']
334
+ prev_p.copy_(p)
335
+ prev_g.copy_(g)
336
+ return
337
+
338
+ if step == 0:
339
+ alpha_0 = settings[0]['alpha_0']
340
+ if alpha_0 is None: alpha_0 = epsilon_step_size(g)
341
+ self.global_state['alpha'] = alpha_0
342
+ prev_p.copy_(p)
343
+ prev_g.copy_(g)
344
+ return
345
+
346
+ sqrt = settings[0]['sqrt']
347
+ alpha = self.global_state.get('alpha', math.inf)
348
+ L = (g - prev_g).global_vector_norm() / (p - prev_p).global_vector_norm()
349
+ eps = torch.finfo(L.dtype).tiny * 2
350
+
351
+ if variant == 1:
352
+ a1 = math.sqrt(1 + theta)*alpha
353
+ val = math.sqrt(2) if sqrt else 2
354
+ if L > eps: a2 = 1 / (val*L)
355
+ else: a2 = math.inf
356
+
357
+ elif variant == 2:
358
+ a1 = math.sqrt(2/3 + theta)*alpha
359
+ a2 = alpha / math.sqrt(max(eps, 2 * alpha**2 * L**2 - 1))
360
+
361
+ else:
362
+ raise ValueError(variant)
363
+
364
+ alpha_new = min(a1, a2)
365
+ if alpha_new < 0: alpha_new = max(a1, a2)
366
+ if alpha_new > eps:
367
+ self.global_state['theta'] = alpha_new/alpha
368
+ self.global_state['alpha'] = alpha_new
369
+
370
+ prev_p.copy_(p)
371
+ prev_g.copy_(g)
372
+
373
+ @torch.no_grad
374
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
375
+ alpha = self.global_state.get('alpha', None)
376
+
377
+ if not _acceptable_alpha(alpha, tensors[0]):
378
+ # alpha isn't None on 1st step
379
+ self.state.clear()
380
+ self.global_state.clear()
381
+ alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
382
+
383
+ torch._foreach_mul_(tensors, alpha)
384
+ return tensors
385
+
386
+ def get_H(self, var):
387
+ return _get_H(self, var)