torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/smoothing/sampling.py
@@ -0,0 +1,300 @@
+ import math
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Sequence
+ from contextlib import nullcontext
+ from functools import partial
+ from typing import Literal, cast
+
+ import torch
+
+ from ...core import Chainable, Modular, Module, Var
+ from ...core.reformulation import Reformulation
+ from ...utils import Distributions, NumberList, TensorList
+ from ..termination import TerminationCriteriaBase, make_termination_criteria
+
+
+ def _reset_except_self(optimizer: Modular, var: Var, self: Module):
+     for m in optimizer.unrolled_modules:
+         if m is not self:
+             m.reset()
+
+
+ class GradientSampling(Reformulation):
+     """Samples and aggregates gradients and values at perturbed points.
+
+     This module can be used for Gaussian homotopy and gradient sampling methods.
+
+     Args:
+         modules (Chainable | None, optional):
+             modules that will optimize the modified objective.
+             if None, returns the gradient of the modified objective as the update. Defaults to None.
+         sigma (float, optional): initial magnitude of the perturbations. Defaults to 1.
+         n (int, optional): number of perturbations per step. Defaults to 100.
+         aggregate (str, optional):
+             how to aggregate values and gradients:
+             - "mean" - uses the mean of the gradients, as in Gaussian homotopy.
+             - "max" - uses the element-wise maximum of the gradients.
+             - "min" - uses the element-wise minimum of the gradients.
+             - "min-norm" - picks the gradient with the lowest norm.
+             - "min-value" - picks the gradient at the point with the lowest loss.
+
+             Defaults to 'mean'.
+         distribution (Distributions, optional): distribution for random perturbations. Defaults to 'gaussian'.
+         include_x0 (bool, optional): whether to include the gradient at the un-perturbed point. Defaults to True.
+         fixed (bool, optional):
+             if True, perturbations are not replaced by new random perturbations until the termination criterion is satisfied. Defaults to True.
+         pre_generate (bool, optional):
+             if True, perturbations are pre-generated before each step.
+             This requires more memory to store all of them,
+             but ensures they do not change when the closure is evaluated multiple times.
+             Defaults to True.
+         termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, optional):
+             a termination criterion module; sigma is multiplied by ``decay`` whenever the criterion is satisfied,
+             and new perturbations are generated if ``fixed``. Defaults to None.
+         decay (float, optional): sigma multiplier applied when the termination criterion is satisfied. Defaults to 2/3.
+         reset_on_termination (bool, optional): whether to reset the states of all other modules on termination. Defaults to True.
+         sigma_strategy (str | None, optional):
+             strategy for adapting sigma. If the condition is satisfied, sigma is multiplied by ``sigma_nplus``,
+             otherwise it is multiplied by ``sigma_nminus``.
+             - "grad-norm" - at least ``sigma_target`` gradients should have a lower norm than at the un-perturbed point.
+             - "value" - at least ``sigma_target`` values (losses) should be lower than at the un-perturbed point.
+             - None - doesn't use adaptive sigma.
+
+             This introduces a side effect to the closure, so it should be left at None if you use
+             a trust region or line search to optimize the modified objective.
+             Defaults to None.
+         sigma_target (int | float, optional):
+             number (or, if a float, fraction) of elements that must satisfy the condition in ``sigma_strategy``. Defaults to 0.2.
+         sigma_nplus (float, optional): sigma multiplier when the ``sigma_strategy`` condition is satisfied. Defaults to 4/3.
+         sigma_nminus (float, optional): sigma multiplier when the ``sigma_strategy`` condition is not satisfied. Defaults to 2/3.
+         seed (int | None, optional): seed. Defaults to None.
+     """
+     def __init__(
+         self,
+         modules: Chainable | None = None,
+         sigma: float = 1.,
+         n: int = 100,
+         aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = 'mean',
+         distribution: Distributions = 'gaussian',
+         include_x0: bool = True,
+
+         fixed: bool = True,
+         pre_generate: bool = True,
+         termination: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
+         decay: float = 2/3,
+         reset_on_termination: bool = True,
+
+         sigma_strategy: Literal['grad-norm', 'value'] | None = None,
+         sigma_target: int | float = 0.2,
+         sigma_nplus: float = 4/3,
+         sigma_nminus: float = 2/3,
+
+         seed: int | None = None,
+     ):
+
+         defaults = dict(sigma=sigma, n=n, aggregate=aggregate, distribution=distribution, seed=seed, include_x0=include_x0, fixed=fixed, decay=decay, reset_on_termination=reset_on_termination, sigma_strategy=sigma_strategy, sigma_target=sigma_target, sigma_nplus=sigma_nplus, sigma_nminus=sigma_nminus, pre_generate=pre_generate)
+         super().__init__(defaults, modules)
+
+         if termination is not None:
+             self.set_child('termination', make_termination_criteria(extra=termination))
+
+     @torch.no_grad
+     def pre_step(self, var):
+         params = TensorList(var.params)
+
+         fixed = self.defaults['fixed']
+
+         # check termination criteria
+         if 'termination' in self.children:
+             termination = cast(TerminationCriteriaBase, self.children['termination'])
+             if termination.should_terminate(var):
+
+                 # decay sigmas
+                 states = [self.state[p] for p in params]
+                 settings = [self.settings[p] for p in params]
+
+                 for state, setting in zip(states, settings):
+                     if 'sigma' not in state: state['sigma'] = setting['sigma']
+                     state['sigma'] *= setting['decay']
+
+                 # reset on sigmas decay
+                 if self.defaults['reset_on_termination']:
+                     var.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+                 # clear perturbations
+                 self.global_state.pop('perts', None)
+
+         # pre-generate perturbations if not already pre-generated or not fixed
+         if self.defaults['pre_generate'] and (('perts' not in self.global_state) or (not fixed)):
+             states = [self.state[p] for p in params]
+             settings = [self.settings[p] for p in params]
+
+             n = self.defaults['n'] - self.defaults['include_x0']
+             generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+             perts = [params.sample_like(self.defaults['distribution'], generator=generator) for _ in range(n)]
+
+             self.global_state['perts'] = perts
+
+     @torch.no_grad
+     def closure(self, backward, closure, params, var):
+         params = TensorList(params)
+         loss_agg = None
+         grad_agg = None
+
+         states = [self.state[p] for p in params]
+         settings = [self.settings[p] for p in params]
+         sigma_inits = [s['sigma'] for s in settings]
+         sigmas = [s.setdefault('sigma', si) for s, si in zip(states, sigma_inits)]
+
+         include_x0 = self.defaults['include_x0']
+         pre_generate = self.defaults['pre_generate']
+         aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = self.defaults['aggregate']
+         sigma_strategy: Literal['grad-norm', 'value'] | None = self.defaults['sigma_strategy']
+         distribution = self.defaults['distribution']
+         generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+
+         n_finite = 0
+         n_good = 0
+         f_0 = None; g_0 = None
+
+         # evaluate at x_0
+         if include_x0:
+             f_0 = cast(torch.Tensor, var.get_loss(backward=backward))
+
+             isfinite = math.isfinite(f_0)
+             if isfinite:
+                 n_finite += 1
+                 loss_agg = f_0
+
+             if backward:
+                 g_0 = var.get_grad()
+                 if isfinite: grad_agg = g_0
+
+         # evaluate at x_0 + p for each perturbation
+         if pre_generate:
+             perts = self.global_state['perts']
+         else:
+             perts = [None] * (self.defaults['n'] - include_x0)
+
+         x_0 = [p.clone() for p in params]
+
+         for pert in perts:
+             loss = None; grad = None
+
+             # generate if not pre-generated
+             if pert is None:
+                 pert = params.sample_like(distribution, generator=generator)
+
+             # add perturbation and evaluate
+             pert = pert * sigmas
+             torch._foreach_add_(params, pert)
+
+             with torch.enable_grad() if backward else nullcontext():
+                 loss = closure(backward)
+
+             if math.isfinite(loss):
+                 n_finite += 1
+
+                 # add loss
+                 if loss_agg is None:
+                     loss_agg = loss
+                 else:
+                     if aggregate == 'mean':
+                         loss_agg += loss
+
+                     elif (aggregate=='min') or (aggregate=='min-value') or (aggregate=='min-norm' and not backward):
+                         loss_agg = loss_agg.clamp(max=loss)
+
+                     elif aggregate == 'max':
+                         loss_agg = loss_agg.clamp(min=loss)
+
+                 # add grad
+                 if backward:
+                     grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                     if grad_agg is None:
+                         grad_agg = grad
+                     else:
+                         if aggregate == 'mean':
+                             torch._foreach_add_(grad_agg, grad)
+
+                         elif aggregate == 'min':
+                             grad_agg_abs = torch._foreach_abs(grad_agg)
+                             torch._foreach_minimum_(grad_agg_abs, torch._foreach_abs(grad))
+                             grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]
+
+                         elif aggregate == 'max':
+                             grad_agg_abs = torch._foreach_abs(grad_agg)
+                             torch._foreach_maximum_(grad_agg_abs, torch._foreach_abs(grad))
+                             grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]
+
+                         elif aggregate == 'min-norm':
+                             if TensorList(grad).global_vector_norm() < TensorList(grad_agg).global_vector_norm():
+                                 grad_agg = grad
+                                 loss_agg = loss
+
+                         elif aggregate == 'min-value':
+                             if loss < loss_agg:
+                                 grad_agg = grad
+                                 loss_agg = loss
+
+             # undo perturbation
+             torch._foreach_copy_(params, x_0)
+
+             # adaptive sigma
+             # by value
+             if sigma_strategy == 'value':
+                 if f_0 is None:
+                     with torch.enable_grad() if backward else nullcontext():
+                         f_0 = closure(False)
+
+                 if loss < f_0:
+                     n_good += 1
+
+             # by gradient norm
+             elif sigma_strategy == 'grad-norm' and backward and math.isfinite(loss):
+                 assert grad is not None
+                 if g_0 is None:
+                     with torch.enable_grad() if backward else nullcontext():
+                         closure()
+                     g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+                 if TensorList(grad).global_vector_norm() < TensorList(g_0).global_vector_norm():
+                     n_good += 1
+
+         # update sigma if strategy is enabled
+         if sigma_strategy is not None:
+
+             sigma_target = self.defaults['sigma_target']
+             if isinstance(sigma_target, float):
+                 sigma_target = int(max(1, n_finite * sigma_target))
+
+             if n_good >= sigma_target:
+                 key = 'sigma_nplus'
+             else:
+                 key = 'sigma_nminus'
+
+             for p in params:
+                 self.state[p]['sigma'] *= self.settings[p][key]
+
+         # if no finite losses, just return inf
+         if n_finite == 0:
+             assert loss_agg is None and grad_agg is None
+             loss = torch.tensor(torch.inf, dtype=params[0].dtype, device=params[0].device)
+             grad = [torch.full_like(p, torch.inf) for p in params]
+             return loss, grad
+
+         assert loss_agg is not None
+
+         # no post processing needed when aggregate is 'max', 'min', 'min-norm', 'min-value'
+         if aggregate != 'mean':
+             return loss_agg, grad_agg
+
+         # on mean divide by number of evals
+         loss_agg /= n_finite
+
+         if backward:
+             assert grad_agg is not None
+             torch._foreach_div_(grad_agg, n_finite)
+
+         return loss_agg, grad_agg
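
With aggregate='mean', the closure above approximates the Gaussian-smoothed objective f_sigma(x) = E[f(x + sigma*u)], u ~ N(0, I), by averaging the losses and gradients from n perturbed evaluations, which is the Gaussian homotopy case mentioned in the docstring. Below is a minimal usage sketch, not part of the diff; it assumes torchzero's Modular optimizer and that GradientSampling, Adam and LR are exposed under the tz.m namespace (those exports are not shown here):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)

    # average the loss and gradient over 32 Gaussian perturbations (sigma=0.1),
    # then feed the smoothed gradient to Adam with a fixed learning rate
    opt = tz.Modular(
        model.parameters(),
        tz.m.GradientSampling(sigma=0.1, n=32, aggregate='mean'),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(100):
        opt.step(closure)

If a line search or trust region should instead act on the smoothed objective, it would be passed through GradientSampling's modules argument so that it re-evaluates the reformulated closure; in that case the docstring advises leaving sigma_strategy at None.
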
torchzero/modules/step_size/__init__.py
@@ -1,2 +1,2 @@
  from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
- from .adaptive import PolyakStepSize, BarzilaiBorwein
+ from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD