torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/optim/wrappers/scipy.py
@@ -1,572 +0,0 @@
- from collections import abc
- from collections.abc import Callable
- from functools import partial
- from typing import Any, Literal
-
- import numpy as np
- import torch
-
- import scipy.optimize
-
- from ...utils import Optimizer, TensorList
- from ...utils.derivatives import (
-     flatten_jacobian,
-     jacobian_and_hessian_mat_wrt,
-     jacobian_wrt,
- )
-
-
- def _ensure_float(x) -> float:
-     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-     if isinstance(x, np.ndarray): return float(x.item())
-     return float(x)
-
- def _ensure_numpy(x):
-     if isinstance(x, torch.Tensor): return x.detach().cpu()
-     if isinstance(x, np.ndarray): return x
-     return np.array(x)
-
- Closure = Callable[[bool], Any]
-
- class ScipyMinimize(Optimizer):
-     """Use scipy.minimize.optimize as pytorch optimizer. Note that this performs full minimization on each step,
-     so usually you would want to perform a single step, although performing multiple steps will refine the
-     solution.
-
-     Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
-     for a detailed description of args.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         method (str | None, optional): type of solver.
-             If None, scipy will select one of BFGS, L-BFGS-B, SLSQP,
-             depending on whether or not the problem has constraints or bounds.
-             Defaults to None.
-         bounds (optional): bounds on variables. Defaults to None.
-         constraints (tuple, optional): constraints definition. Defaults to ().
-         tol (float | None, optional): Tolerance for termination. Defaults to None.
-         callback (Callable | None, optional): A callable called after each iteration. Defaults to None.
-         options (dict | None, optional): A dictionary of solver options. Defaults to None.
-         jac (str, optional): Method for computing the gradient vector.
-             Only for CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
-             In addition to scipy options, this supports 'autograd', which uses pytorch autograd.
-             This setting is ignored for methods that don't require gradient. Defaults to 'autograd'.
-         hess (str, optional):
-             Method for computing the Hessian matrix.
-             Only for Newton-CG, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
-             This setting is ignored for methods that don't require hessian. Defaults to 'autograd'.
-         tikhonov (float, optional):
-             optional hessian regularizer value. Only has effect for methods that require hessian.
-     """
-     def __init__(
-         self,
-         params,
-         method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
-                         'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
-                         'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
-                         'trust-krylov'] | str | None = None,
-         lb = None,
-         ub = None,
-         constraints = (),
-         tol: float | None = None,
-         callback = None,
-         options = None,
-         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
-         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
-     ):
-         defaults = dict(lb=lb, ub=ub)
-         super().__init__(params, defaults)
-         self.method = method
-         self.constraints = constraints
-         self.tol = tol
-         self.callback = callback
-         self.options = options
-
-         self.jac = jac
-         self.hess = hess
-
-         self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
-             'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
-             'trust-ncg', 'trust-krylov', 'trust-exact', 'trust-constr',
-         ])
-         self.use_hess_autograd = isinstance(hess, str) and hess.lower() == 'autograd' and method is not None and method.lower() in [
-             'newton-cg', 'dogleg', 'trust-ncg', 'trust-krylov', 'trust-exact'
-         ]
-
-         # jac in scipy is '2-point', '3-point', 'cs', True or None.
-         if self.jac == 'autograd':
-             if self.use_jac_autograd: self.jac = True
-             else: self.jac = None
-
-
-     def _hess(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         with torch.enable_grad():
-             value = closure(False)
-             _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
-         return H.numpy(force=True)
-
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         # set params to x
-         params.from_vec_(torch.from_numpy(x).to(params[0], copy=False))
-
-         # return value and maybe gradients
-         if self.use_jac_autograd:
-             with torch.enable_grad(): value = _ensure_float(closure())
-             grad = params.ensure_grad_().grad.to_vec().numpy(force=True)
-             # slsqp requires float64
-             if self.method.lower() == 'slsqp': grad = grad.astype(np.float64)
-             return value, grad
-         return _ensure_float(closure(False))
-
-     @torch.no_grad
-     def step(self, closure: Closure):# pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-         params = self.get_params()
-
-         # determine hess argument
-         if self.hess == 'autograd':
-             if self.use_hess_autograd: hess = partial(self._hess, params = params, closure = closure)
-             else: hess = None
-         else: hess = self.hess
-
-         x0 = params.to_vec().numpy(force=True)
-
-         # make bounds
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         bounds = None
-         if any(b is not None for b in lb) or any(b is not None for b in ub):
-             bounds = []
-             for p, l, u in zip(params, lb, ub):
-                 bounds.extend([(l, u)] * p.numel())
-
-         if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
-             x0 = x0.astype(np.float64) # those methods error without this
-
-         res = scipy.optimize.minimize(
-             partial(self._objective, params = params, closure = closure),
-             x0 = x0,
-             method=self.method,
-             bounds=bounds,
-             constraints=self.constraints,
-             tol=self.tol,
-             callback=self.callback,
-             options=self.options,
-             jac = self.jac,
-             hess = hess,
-         )
-
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return res.fun
-
-
-
- class ScipyRootOptimization(Optimizer):
-     """Optimization via using scipy.optimize.root on gradients, mainly for experimenting!
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         method (str | None, optional): _description_. Defaults to None.
-         tol (float | None, optional): _description_. Defaults to None.
-         callback (_type_, optional): _description_. Defaults to None.
-         options (_type_, optional): _description_. Defaults to None.
-         jac (T.Literal['2, optional): _description_. Defaults to 'autograd'.
-     """
-     def __init__(
-         self,
-         params,
-         method: Literal[
-             "hybr",
-             "lm",
-             "broyden1",
-             "broyden2",
-             "anderson",
-             "linearmixing",
-             "diagbroyden",
-             "excitingmixing",
-             "krylov",
-             "df-sane",
-         ] = 'hybr',
-         tol: float | None = None,
-         callback = None,
-         options = None,
-         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
-     ):
-         super().__init__(params, {})
-         self.method = method
-         self.tol = tol
-         self.callback = callback
-         self.options = options
-
-         self.jac = jac
-         if self.jac == 'autograd': self.jac = True
-
-         # those don't require jacobian
-         if self.method.lower() in ('broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden', 'excitingmixing', 'krylov', 'df-sane'):
-             self.jac = None
-
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         # set params to x
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-
-         # return gradients and maybe hessian
-         if self.jac:
-             with torch.enable_grad():
-                 self.value = closure(False)
-                 if not isinstance(self.value, torch.Tensor):
-                     raise TypeError(f"Autograd jacobian requires closure to return torch.Tensor, got {type(self.value)}")
-                 g, H = jacobian_and_hessian_mat_wrt([self.value], wrt=params)
-             return g.detach().cpu().numpy(), H.detach().cpu().numpy()
-
-         # return the gradients
-         with torch.enable_grad(): self.value = closure()
-         jac = params.ensure_grad_().grad.to_vec()
-         return jac.detach().cpu().numpy()
-
-     @torch.no_grad
-     def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
-
-         res = scipy.optimize.root(
-             partial(self._objective, params = params, closure = closure),
-             x0 = x0,
-             method=self.method,
-             tol=self.tol,
-             callback=self.callback,
-             options=self.options,
-             jac = self.jac,
-         )
-
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return res.fun
-
-
- class ScipyLeastSquaresOptimization(Optimizer):
-     """Optimization via using scipy.optimize.least_squares on gradients, mainly for experimenting!
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         method (str | None, optional): _description_. Defaults to None.
-         tol (float | None, optional): _description_. Defaults to None.
-         callback (_type_, optional): _description_. Defaults to None.
-         options (_type_, optional): _description_. Defaults to None.
-         jac (T.Literal['2, optional): _description_. Defaults to 'autograd'.
-     """
-     def __init__(
-         self,
-         params,
-         method='trf',
-         jac='autograd',
-         bounds=(-np.inf, np.inf),
-         ftol=1e-8, xtol=1e-8, gtol=1e-8, x_scale=1.0, loss='linear',
-         f_scale=1.0, diff_step=None, tr_solver=None, tr_options=None,
-         jac_sparsity=None, max_nfev=None, verbose=0
-     ):
-         super().__init__(params, {})
-         kwargs = locals().copy()
-         del kwargs['self'], kwargs['params'], kwargs['__class__'], kwargs['jac']
-         self._kwargs = kwargs
-
-         self.jac = jac
-
-
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         # set params to x
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-
-         # return the gradients
-         with torch.enable_grad(): self.value = closure()
-         jac = params.ensure_grad_().grad.to_vec()
-         return jac.numpy(force=True)
-
-     def _hess(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         with torch.enable_grad():
-             value = closure(False)
-             _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
-         return H.numpy(force=True)
-
-     @torch.no_grad
-     def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
-
-         if self.jac == 'autograd': jac = partial(self._hess, params = params, closure = closure)
-         else: jac = self.jac
-
-         res = scipy.optimize.least_squares(
-             partial(self._objective, params = params, closure = closure),
-             x0 = x0,
-             jac=jac, # type:ignore
-             **self._kwargs
-         )
-
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return res.fun
-
-
-
-
- class ScipyDE(Optimizer):
-     """Use scipy.minimize.differential_evolution as pytorch optimizer. Note that this performs full minimization on each step,
-     so usually you would want to perform a single step. This also requires bounds to be specified.
-
-     Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
-     for all other args.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         bounds (tuple[float,float], optional): tuple with lower and upper bounds.
-             DE requires bounds to be specified. Defaults to None.
-
-         other args:
-             refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
-     """
-     def __init__(
-         self,
-         params,
-         lb: float,
-         ub: float,
-         strategy: Literal['best1bin', 'best1exp', 'rand1bin', 'rand1exp', 'rand2bin', 'rand2exp',
-                           'randtobest1bin', 'randtobest1exp', 'currenttobest1bin', 'currenttobest1exp',
-                           'best2exp', 'best2bin'] = 'best1bin',
-         maxiter: int = 1000,
-         popsize: int = 15,
-         tol: float = 0.01,
-         mutation = (0.5, 1),
-         recombination: float = 0.7,
-         seed = None,
-         callback = None,
-         disp: bool = False,
-         polish: bool = False,
-         init: str = 'latinhypercube',
-         atol: int = 0,
-         updating: str = 'immediate',
-         workers: int = 1,
-         constraints = (),
-         *,
-         integrality = None,
-
-     ):
-         super().__init__(params, lb=lb, ub=ub)
-
-         kwargs = locals().copy()
-         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
-         self._kwargs = kwargs
-
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
-     @torch.no_grad
-     def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
-
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         bounds = []
-         for p, l, u in zip(params, lb, ub):
-             bounds.extend([(l, u)] * p.numel())
-
-         res = scipy.optimize.differential_evolution(
-             partial(self._objective, params = params, closure = closure),
-             x0 = x0,
-             bounds=bounds,
-             **self._kwargs
-         )
-
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return res.fun
-
-
-
- class ScipyDualAnnealing(Optimizer):
-     def __init__(
-         self,
-         params,
-         lb: float,
-         ub: float,
-         maxiter=1000,
-         minimizer_kwargs=None,
-         initial_temp=5230.0,
-         restart_temp_ratio=2.0e-5,
-         visit=2.62,
-         accept=-5.0,
-         maxfun=1e7,
-         rng=None,
-         no_local_search=False,
-     ):
-         super().__init__(params, lb=lb, ub=ub)
-
-         kwargs = locals().copy()
-         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
-         self._kwargs = kwargs
-
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
-     @torch.no_grad
-     def step(self, closure: Closure):
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         bounds = []
-         for p, l, u in zip(params, lb, ub):
-             bounds.extend([(l, u)] * p.numel())
-
-         res = scipy.optimize.dual_annealing(
-             partial(self._objective, params = params, closure = closure),
-             x0 = x0,
-             bounds=bounds,
-             **self._kwargs
-         )
-
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return res.fun
-
-
-
- class ScipySHGO(Optimizer):
-     def __init__(
-         self,
-         params,
-         lb: float,
-         ub: float,
-         constraints = None,
-         n: int = 100,
-         iters: int = 1,
-         callback = None,
-         minimizer_kwargs = None,
-         options = None,
-         sampling_method: str = 'simplicial',
-     ):
-         super().__init__(params, lb=lb, ub=ub)
-
-         kwargs = locals().copy()
-         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
-         self._kwargs = kwargs
-
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
-     @torch.no_grad
-     def step(self, closure: Closure):
-         params = self.get_params()
-
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         bounds = []
-         for p, l, u in zip(params, lb, ub):
-             bounds.extend([(l, u)] * p.numel())
-
-         res = scipy.optimize.shgo(
-             partial(self._objective, params = params, closure = closure),
-             bounds=bounds,
-             **self._kwargs
-         )
-
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return res.fun
-
-
- class ScipyDIRECT(Optimizer):
-     def __init__(
-         self,
-         params,
-         lb: float,
-         ub: float,
-         maxfun: int | None = 1000,
-         maxiter: int = 1000,
-         eps: float = 0.0001,
-         locally_biased: bool = True,
-         f_min: float = -np.inf,
-         f_min_rtol: float = 0.0001,
-         vol_tol: float = 1e-16,
-         len_tol: float = 0.000001,
-         callback = None,
-     ):
-         super().__init__(params, lb=lb, ub=ub)
-
-         kwargs = locals().copy()
-         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
-         self._kwargs = kwargs
-
-     def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
-         if self.raised: return np.inf
-         try:
-             params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-             return _ensure_float(closure(False))
-         except Exception as e:
-             # he he he ha, I found a way to make exceptions work in fcmaes and scipy direct
-             self.e = e
-             self.raised = True
-             return np.inf
-
-     @torch.no_grad
-     def step(self, closure: Closure):
-         self.raised = False
-         self.e = None
-
-         params = self.get_params()
-
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         bounds = []
-         for p, l, u in zip(params, lb, ub):
-             bounds.extend([(l, u)] * p.numel())
-
-         res = scipy.optimize.direct(
-             partial(self._objective, params=params, closure=closure),
-             bounds=bounds,
-             **self._kwargs
-         )
-
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-
-         if self.e is not None: raise self.e from None
-         return res.fun
-
-
-
-
- class ScipyBrute(Optimizer):
-     def __init__(
-         self,
-         params,
-         lb: float,
-         ub: float,
-         Ns: int = 20,
-         full_output: int = 0,
-         finish = scipy.optimize.fmin,
-         disp: bool = False,
-         workers: int = 1
-     ):
-         super().__init__(params, lb=lb, ub=ub)
-
-         kwargs = locals().copy()
-         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
-         self._kwargs = kwargs
-
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
-     @torch.no_grad
-     def step(self, closure: Closure):
-         params = self.get_params()
-
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         bounds = []
-         for p, l, u in zip(params, lb, ub):
-             bounds.extend([(l, u)] * p.numel())
-
-         x0 = scipy.optimize.brute(
-             partial(self._objective, params = params, closure = closure),
-             ranges=bounds,
-             **self._kwargs
-         )
-         params.from_vec_(torch.from_numpy(x0).to(device = params[0].device, dtype=params[0].dtype, copy=False))
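All of the removed wrappers above share the closure-based pattern described in the ScipyMinimize docstring: step() flattens the parameters into a NumPy vector, builds an objective from the closure, hands it to the corresponding scipy.optimize routine, and writes the result back into the parameters. A minimal usage sketch of that pattern, assuming the 0.3.x import path (in 0.4.0 these wrappers move into the new torchzero/optim/wrappers/scipy/ package listed above, whose exact re-exports are not shown here):

    import torch
    from torchzero.optim.wrappers.scipy import ScipyMinimize  # 0.3.x location, removed in 0.4.0

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)

    opt = ScipyMinimize(model.parameters(), method='l-bfgs-b')

    def closure(backward=True):
        # the wrapper calls closure(False) for value-only evaluations
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            model.zero_grad()
            loss.backward()
        return loss

    # a single step runs a full scipy.optimize.minimize solve
    final_loss = opt.step(closure)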
torchzero/utils/linalg/__init__.py
@@ -1,12 +0,0 @@
- from . import linear_operator
- from .matrix_funcs import (
-     eigvals_func,
-     inv_sqrt_2x2,
-     matrix_power_eigh,
-     singular_vals_func,
-     x_inv,
- )
- from .orthogonalize import gram_schmidt
- from .qr import qr_householder
- from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
- from .svd import randomized_svd
torchzero/utils/linalg/matrix_funcs.py
@@ -1,87 +0,0 @@
- import warnings
- from collections.abc import Callable
-
- import torch
-
- def eigvals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
-     L, Q = torch.linalg.eigh(A) # pylint:disable=not-callable
-     L = fn(L)
-     return (Q * L.unsqueeze(-2)) @ Q.mH
-
- def singular_vals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
-     U, S, V = torch.linalg.svd(A) # pylint:disable=not-callable
-     S = fn(S)
-     return (U * S.unsqueeze(-2)) @ V.mT
-
- def matrix_power_eigh(A: torch.Tensor, pow:float):
-     L, Q = torch.linalg.eigh(A) # pylint:disable=not-callable
-     if pow % 2 != 0: L.clip_(min = torch.finfo(A.dtype).tiny * 2)
-     return (Q * L.pow(pow).unsqueeze(-2)) @ Q.mH
-
-
- def inv_sqrt_2x2(A: torch.Tensor, force_pd: bool=False) -> torch.Tensor:
-     """Inverse square root of a possibly batched 2x2 matrix using a general formula for 2x2 matrices so that this is way faster than torch linalg. I tried doing a hierarchical 2x2 preconditioning but it didn't work well."""
-     eps = torch.finfo(A.dtype).tiny * 2
-
-     a = A[..., 0, 0]
-     b = A[..., 0, 1]
-     c = A[..., 1, 0]
-     d = A[..., 1, 1]
-
-     det = (a * d).sub_(b * c)
-     trace = a + d
-
-     if force_pd:
-         # add smallest eigenvalue magnitude to diagonal to force PD
-         # could also abs or clip eigenvalues bc there is a formula for eigenvectors
-         term1 = trace/2
-         term2 = (trace.pow(2).div_(4).sub_(det)).clamp_(min=eps).sqrt_()
-         y1 = term1 + term2
-         y2 = term1 - term2
-         smallest_eigval = torch.minimum(y1, y2).neg_().clamp_(min=0) + eps
-         a = a+smallest_eigval
-         d = d+smallest_eigval
-
-         # recalculate det and trace witg new a and b
-         det = (a * d).sub_(b * c)
-         trace = a + d
-
-     s = (det.clamp(min=eps)).sqrt_()
-
-     tau_squared = trace + 2 * s
-     tau = (tau_squared.clamp(min=eps)).sqrt_()
-
-     denom = s * tau
-
-     coeff = (denom.clamp(min=eps)).reciprocal_().unsqueeze(-1).unsqueeze(-1)
-
-     row1 = torch.stack([d + s, -b], dim=-1)
-     row2 = torch.stack([-c, a + s], dim=-1)
-     M = torch.stack([row1, row2], dim=-2)
-
-     return coeff * M
-
-
- def x_inv(diag: torch.Tensor,antidiag: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-     """invert a matrix with diagonal and anti-diagonal non zero elements, with no checks that it is invertible"""
-     n = diag.shape[0]
-     if diag.dim() != 1 or antidiag.dim() != 1 or antidiag.shape[0] != n:
-         raise ValueError("Input tensors must be 1D and have the same size.")
-     if n == 0:
-         return torch.empty_like(diag), torch.empty_like(antidiag)
-
-     # opposite indexes
-     diag_rev = torch.flip(diag, dims=[0])
-     antidiag_rev = torch.flip(antidiag, dims=[0])
-
-     # determinants
-     # det_i = d[i] * d[n-1-i] - a[i] * a[n-1-i]
-     determinant_vec = diag * diag_rev - antidiag * antidiag_rev
-
-     # inverse diagonal elements: y_d[i] = d[n-1-i] / det_i
-     inv_diag_vec = diag_rev / determinant_vec
-
-     # inverse anti-diagonal elements: y_a[i] = -a[i] / det_i
-     inv_anti_diag_vec = -antidiag / determinant_vec
-
-     return inv_diag_vec, inv_anti_diag_vec
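The removed eigvals_func and matrix_power_eigh helpers follow the same recipe: factor A = Q diag(L) Qᴴ with torch.linalg.eigh and apply a scalar function to the eigenvalues. A small self-contained check of that recipe (my own sketch, not the replacement code under torchzero/linalg/): pushing the inverse square root through the eigenvalues of an SPD matrix should recover the identity when sandwiched around A.

    import torch

    def apply_to_eigvals(A, fn):
        # same idea as the removed eigvals_func: A = Q diag(L) Q^H  ->  Q diag(fn(L)) Q^H
        L, Q = torch.linalg.eigh(A)
        return (Q * fn(L).unsqueeze(-2)) @ Q.mH

    A = torch.randn(5, 5, dtype=torch.float64)
    A = A @ A.mT + 5 * torch.eye(5, dtype=torch.float64)  # SPD test matrix

    inv_sqrt = apply_to_eigvals(A, lambda L: L.rsqrt())
    # A^{-1/2} @ A @ A^{-1/2} should be (approximately) the identity
    print(torch.allclose(inv_sqrt @ A @ inv_sqrt, torch.eye(5, dtype=torch.float64), atol=1e-8))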
torchzero/utils/linalg/orthogonalize.py
@@ -1,12 +0,0 @@
- from typing import overload
- import torch
- from ..tensorlist import TensorList
-
- @overload
- def gram_schmidt(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ...
- @overload
- def gram_schmidt(x: TensorList, y: TensorList) -> tuple[TensorList, TensorList]: ...
- def gram_schmidt(x, y):
-     """makes two orthogonal vectors, only y is changed"""
-     min = torch.finfo(x.dtype).tiny * 2
-     return x, y - (x*y) / (x*x).clip(min=min)
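For reference, a textbook single Gram-Schmidt step on flat 1-D vectors subtracts the projection of y onto x using the inner product ⟨x, y⟩. A standalone sketch of that projection (my own illustration; it is not the replacement implementation in torchzero/linalg/orthogonalize.py, which this diff does not show):

    import torch

    def project_out(x, y, eps=None):
        # single Gram-Schmidt step: remove the component of y along x,
        # i.e. y - x * <x, y> / <x, x>
        if eps is None:
            eps = torch.finfo(x.dtype).tiny * 2
        return y - x * (x @ y) / (x @ x).clamp(min=eps)

    x = torch.tensor([1.0, 2.0, 3.0])
    y = torch.tensor([4.0, 5.0, 6.0])
    y_orth = project_out(x, y)
    print(torch.dot(x, y_orth))  # ~0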