torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/optim/wrappers/pybobyqa.py
@@ -0,0 +1,124 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import torch
+ import pybobyqa
+
+ from ...utils import TensorList
+ from .wrapper import WrapperBase
+
+ Closure = Callable[[bool], Any]
+
+
+ class PyBobyqaWrapper(WrapperBase):
+     """Use Py-BOBYQA as a PyTorch optimizer.
+
+     Note that this performs full minimization on each step,
+     so usually you would want to perform a single step, although performing multiple steps will refine the
+     solution.
+
+     See https://numericalalgorithmsgroup.github.io/pybobyqa/build/html/userguide.html for detailed descriptions of arguments.
+
+     Args:
+         params (Iterable): iterable of parameters to optimize or dicts defining parameter groups.
+         lb (float | None, optional): optional lower bounds. Defaults to None.
+         ub (float | None, optional): optional upper bounds. Defaults to None.
+         projections (list[Callable] | None, optional):
+             a list of functions defining the Euclidean projections for each general convex constraint C_i.
+             Each element of the list projections is a function that takes an input vector x (numpy array)
+             and returns the closest point to x that is in C_i. Defaults to None.
+         npt (int | None, optional): the number of interpolation points to use. Defaults to None.
+         rhobeg (float | None, optional):
+             the initial value of the trust region radius. Defaults to None.
+         rhoend (float, optional):
+             minimum allowed value of trust region radius, which determines when a successful
+             termination occurs. Defaults to 1e-8.
+         maxfun (int | None, optional):
+             the maximum number of objective evaluations the algorithm may request,
+             default is min(100(n+1), 1000). Defaults to None.
+         nsamples (Callable | None, optional):
+             a Python function nsamples(delta, rho, iter, nrestarts)
+             which returns the number of times to evaluate objfun at a given point.
+             This is only applicable for objectives with stochastic noise,
+             when averaging multiple evaluations at the same point produces a more accurate value.
+             The input parameters are the trust region radius (delta),
+             the lower bound on the trust region radius (rho),
+             how many iterations the algorithm has been running for (iter),
+             and how many restarts have been performed (nrestarts).
+             Default is no averaging (i.e. nsamples(delta, rho, iter, nrestarts)=1).
+             Defaults to None.
+         user_params (dict | None, optional):
+             dictionary of advanced parameters,
+             see https://numericalalgorithmsgroup.github.io/pybobyqa/build/html/advanced.html.
+             Defaults to None.
+         objfun_has_noise (bool, optional):
+             a flag to indicate whether or not objfun has stochastic noise;
+             i.e. will calling objfun(x) multiple times at the same value of x give different results?
+             This is used to set some sensible default parameters (including using multiple restarts),
+             all of which can be overridden by the values provided in user_params. Defaults to False.
+         seek_global_minimum (bool, optional):
+             a flag to indicate whether to search for a global minimum, rather than a local minimum.
+             This is used to set some sensible default parameters,
+             all of which can be overridden by the values provided in user_params.
+             If True, both upper and lower bounds must be set.
+             Note that Py-BOBYQA only implements a heuristic method,
+             so there are no guarantees it will find a global minimum.
+             However, by using this flag, it is more likely to escape local minima
+             if there are better values nearby. The method used is a multiple restart mechanism,
+             where we repeatedly re-initialize Py-BOBYQA from the best point found so far,
+             but where we use a larger trust region radius each time
+             (note: this is different to the more common multi-start approach to global optimization).
+             Defaults to False.
+         scaling_within_bounds (bool, optional):
+             a flag to indicate whether the algorithm should internally shift and scale the entries of x
+             so that the bounds become 0 <= x <= 1. This is useful if you are setting bounds and the
+             bounds have different orders of magnitude. If scaling_within_bounds=True,
+             the values of rhobeg and rhoend apply to the shifted variables. Defaults to False.
+         do_logging (bool, optional):
+             a flag to indicate whether logging output should be produced.
+             This is not automatically visible unless you use the Python logging module. Defaults to True.
+         print_progress (bool, optional):
+             a flag to indicate whether to print a per-iteration progress log to terminal. Defaults to False.
+     """
+     def __init__(
+         self,
+         params,
+         lb: float | None = None,
+         ub: float | None = None,
+         projections = None,
+         npt: int | None = None,
+         rhobeg: float | None = None,
+         rhoend: float = 1e-8,
+         maxfun: int | None = None,
+         nsamples: Callable | None = None,
+         user_params: dict[str, Any] | None = None,
+         objfun_has_noise: bool = False,
+         seek_global_minimum: bool = False,
+         scaling_within_bounds: bool = False,
+         do_logging: bool = True,
+         print_progress: bool = False,
+     ):
+         super().__init__(params, dict(lb=lb, ub=ub))
+         kwargs = locals().copy()
+         for k in ["self", "__class__", "params", "lb", "ub"]:
+             del kwargs[k]
+         self._kwargs = kwargs
+
+     @torch.no_grad
+     def step(self, closure: Closure):
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+         bounds = self._get_bounds()
+
+         soln: pybobyqa.solver.OptimResults = pybobyqa.solve(
+             objfun=partial(self._f, closure=closure, params=params),
+             x0=x0,
+             bounds=bounds,
+             **self._kwargs
+         )
+
+         params.from_vec_(torch.as_tensor(soln.x, device=params[0].device, dtype=params[0].dtype))
+         return soln.f
+
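For orientation, here is a minimal usage sketch of the new wrapper. The import path is assumed from the file list above; the single-argument closure convention (closure(backward: bool)) is an assumption based on the Closure alias in this file, and since Py-BOBYQA is derivative-free the flag is simply ignored:

    import torch
    # import path assumed from the file list above
    from torchzero.optim.wrappers.pybobyqa import PyBobyqaWrapper

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)

    opt = PyBobyqaWrapper(model.parameters(), lb=-5, ub=5, maxfun=500)

    def closure(backward: bool = True):
        # derivative-free solver: the backward flag is ignored
        return torch.nn.functional.mse_loss(model(X), y)

    loss = opt.step(closure)  # one step runs a full Py-BOBYQA minimization
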
torchzero/optim/wrappers/scipy/__init__.py
@@ -0,0 +1,7 @@
+ from .basin_hopping import ScipyBasinHopping
+ from .brute import ScipyBrute
+ from .differential_evolution import ScipyDE
+ from .direct import ScipyDIRECT
+ from .dual_annealing import ScipyDualAnnealing
+ from .minimize import ScipyMinimize
+ from .sgho import ScipySHGO
torchzero/optim/wrappers/scipy/basin_hopping.py
@@ -0,0 +1,117 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ....utils import TensorList
+ from ..wrapper import WrapperBase
+ from .minimize import _use_jac_hess_hessp
+
+ Closure = Callable[[bool], Any]
+
+
+ class ScipyBasinHopping(WrapperBase):
+     def __init__(
+         self,
+         params,
+         niter: int = 100,
+         T: float = 1,
+         stepsize: float = 0.5,
+         minimizer_kwargs: dict | None = None,
+         take_step: Callable | None = None,
+         accept_test: Callable | None = None,
+         callback: Callable | None = None,
+         interval: int = 50,
+         disp: bool = False,
+         niter_success: int | None = None,
+         rng: int | np.random.Generator | None = None,
+         lb: float | None = None,
+         ub: float | None = None,
+         method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
+                         'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
+                         'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
+                         'trust-krylov'] | str | None = None,
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
+         use_hessp: bool = True,
+
+         *,
+         target_accept_rate: float = 0.5,
+         stepwise_factor: float = 0.9
+     ):
+         super().__init__(params, dict(lb=lb, ub=ub))
+
+         kwargs = locals().copy()
+         del kwargs['self'], kwargs['params'], kwargs['__class__'], kwargs["minimizer_kwargs"]
+         del kwargs['method'], kwargs["jac"], kwargs['hess'], kwargs['use_hessp']
+         del kwargs["lb"], kwargs["ub"]
+         self._kwargs = kwargs
+
+         self._minimizer_kwargs = minimizer_kwargs
+         self.method = method
+         self.hess = hess
+         self.jac, self.use_jac_autograd, self.use_hess_autograd, self.use_hessp = _use_jac_hess_hessp(method, jac, hess, use_hessp)
+
+     def _jac(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         f, g = self._f_g(x, params, closure)
+         return g
+
+     def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         if self.use_jac_autograd:
+             f, g = self._f_g(x, params, closure)
+             if self.method is not None and self.method.lower() == 'slsqp': g = g.astype(np.float64) # slsqp requires float64
+             return f, g
+
+         return self._f(x, params, closure)
+
+     def _hess(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, H = self._f_g_H(x, params, closure)
+         return H
+
+     def _hessp(self, x: np.ndarray, p: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, Hvp = self._f_g_Hvp(x, p, params, closure)
+         return Hvp
+
+     @torch.no_grad
+     def step(self, closure: Closure):
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+         bounds = self._get_bounds()
+
+         # determine hess argument
+         hess = self.hess
+         hessp = None
+         if hess == 'autograd':
+             if self.use_hess_autograd:
+                 if self.use_hessp:
+                     hessp = partial(self._hessp, params=params, closure=closure)
+                     hess = None
+                 else:
+                     hess = partial(self._hess, params=params, closure=closure)
+             # hess = 'autograd' but method doesn't use hess
+             else:
+                 hess = None
+
+
+         if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
+             x0 = x0.astype(np.float64) # those methods error without this
+
+         minimizer_kwargs = self._minimizer_kwargs.copy() if self._minimizer_kwargs is not None else {}
+         minimizer_kwargs.setdefault("method", self.method)
+         minimizer_kwargs.setdefault("jac", self.jac)
+         minimizer_kwargs.setdefault("hess", hess)
+         minimizer_kwargs.setdefault("hessp", hessp)
+         minimizer_kwargs.setdefault("bounds", bounds)
+
+         res = scipy.optimize.basinhopping(
+             partial(self._objective, params=params, closure=closure),
+             x0=x0,  # reuse the (possibly float64-cast) starting vector computed above
+             minimizer_kwargs=minimizer_kwargs,
+             **self._kwargs
+         )
+
+         params.from_vec_(torch.as_tensor(res.x, device=params[0].device, dtype=params[0].dtype))
+         return res.fun
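As above, a minimal usage sketch. The import path is inferred from the file list; the closure that runs a backward pass when the flag is true is an assumption about how the jac='autograd' path obtains gradients:

    import torch
    # import path assumed from the file list above
    from torchzero.optim.wrappers.scipy.basin_hopping import ScipyBasinHopping

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)

    opt = ScipyBasinHopping(model.parameters(), niter=25, method='l-bfgs-b', jac='autograd')

    def closure(backward: bool = True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            model.zero_grad()
            loss.backward()
        return loss

    best = opt.step(closure)  # one step runs the full basin-hopping search
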
torchzero/optim/wrappers/scipy/brute.py
@@ -0,0 +1,48 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ....utils import TensorList
+ from ..wrapper import WrapperBase
+
+ Closure = Callable[[bool], Any]
+
+
+
+ class ScipyBrute(WrapperBase):
+     def __init__(
+         self,
+         params,
+         lb: float,
+         ub: float,
+         Ns: int = 20,
+         finish = scipy.optimize.fmin,
+         disp: bool = False,
+         workers: int = 1
+     ):
+         super().__init__(params, dict(lb=lb, ub=ub))
+
+         kwargs = locals().copy()
+         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+         self._kwargs = kwargs
+
+     @torch.no_grad
+     def step(self, closure: Closure):
+         params = TensorList(self._get_params())
+         bounds = self._get_bounds()
+         assert bounds is not None
+
+         res, fval, grid, Jout = scipy.optimize.brute(
+             partial(self._f, params=params, closure=closure),
+             ranges=bounds,
+             full_output=True,
+             **self._kwargs
+         )
+
+         params.from_vec_(torch.as_tensor(res, device=params[0].device, dtype=params[0].dtype))
+
+         return fval
torchzero/optim/wrappers/scipy/differential_evolution.py
@@ -0,0 +1,80 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ....utils import TensorList
+ from ..wrapper import WrapperBase
+
+ Closure = Callable[[bool], Any]
+
+
+
+
+
+ class ScipyDE(WrapperBase):
+     """Use scipy.optimize.differential_evolution as a PyTorch optimizer. Note that this performs full minimization on each step,
+     so usually you would want to perform a single step. This also requires bounds to be specified.
+
+     Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
+     for all other args.
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         lb (float): lower bound. DE requires bounds to be specified.
+         ub (float): upper bound. DE requires bounds to be specified.
+
+     other args:
+         refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
+     """
+     def __init__(
+         self,
+         params,
+         lb: float,
+         ub: float,
+         strategy: Literal['best1bin', 'best1exp', 'rand1bin', 'rand1exp', 'rand2bin', 'rand2exp',
+                           'randtobest1bin', 'randtobest1exp', 'currenttobest1bin', 'currenttobest1exp',
+                           'best2exp', 'best2bin'] = 'best1bin',
+         maxiter: int = 1000,
+         popsize: int = 15,
+         tol: float = 0.01,
+         mutation = (0.5, 1),
+         recombination: float = 0.7,
+         seed = None,
+         callback = None,
+         disp: bool = False,
+         polish: bool = True,
+         init: str = 'latinhypercube',
+         atol: int = 0,
+         updating: str = 'immediate',
+         workers: int = 1,
+         constraints = (),
+         *,
+         integrality = None,
+
+     ):
+         super().__init__(params, dict(lb=lb, ub=ub))
+
+         kwargs = locals().copy()
+         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+         self._kwargs = kwargs
+
+     @torch.no_grad
+     def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+         bounds = self._get_bounds()
+         assert bounds is not None
+
+         res = scipy.optimize.differential_evolution(
+             partial(self._f, params=params, closure=closure),
+             x0=x0,
+             bounds=bounds,
+             **self._kwargs
+         )
+
+         params.from_vec_(torch.as_tensor(res.x, device=params[0].device, dtype=params[0].dtype))
+         return res.fun
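A minimal usage sketch (import path assumed from the file list above; DE is derivative-free, so the closure only needs to return the loss, and lb/ub must be given):

    import torch
    # import path assumed from the file list above
    from torchzero.optim.wrappers.scipy.differential_evolution import ScipyDE

    x = torch.nn.Parameter(torch.zeros(8))
    opt = ScipyDE([x], lb=-10.0, ub=10.0, maxiter=200, popsize=20)

    def closure(backward: bool = False):
        # derivative-free: no backward pass needed
        return ((x - 3.0) ** 2).sum()

    best = opt.step(closure)  # a single step runs the full DE minimization
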
torchzero/optim/wrappers/scipy/direct.py
@@ -0,0 +1,69 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ....utils import TensorList
+ from ..wrapper import WrapperBase
+
+ Closure = Callable[[bool], Any]
+
+
+
+
+ class ScipyDIRECT(WrapperBase):
+     def __init__(
+         self,
+         params,
+         lb: float,
+         ub: float,
+         maxfun: int | None = 1000,
+         maxiter: int = 1000,
+         eps: float = 0.0001,
+         locally_biased: bool = True,
+         f_min: float = -np.inf,
+         f_min_rtol: float = 0.0001,
+         vol_tol: float = 1e-16,
+         len_tol: float = 0.000001,
+         callback = None,
+     ):
+         super().__init__(params, dict(lb=lb, ub=ub))
+
+         kwargs = locals().copy()
+         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+         self._kwargs = kwargs
+
+     def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
+         if self.raised: return np.inf
+         try:
+             return self._f(x, params, closure)
+
+         except Exception as e:
+             # this makes exceptions work in fcmaes and scipy direct
+             self.e = e
+             self.raised = True
+             return np.inf
+
+     @torch.no_grad
+     def step(self, closure: Closure):
+         self.raised = False
+         self.e = None
+
+         params = TensorList(self._get_params())
+         bounds = self._get_bounds()
+         assert bounds is not None
+
+         res = scipy.optimize.direct(
+             partial(self._objective, params=params, closure=closure),
+             bounds=bounds,
+             **self._kwargs
+         )
+
+         params.from_vec_(torch.as_tensor(res.x, device=params[0].device, dtype=params[0].dtype))
+
+         if self.e is not None: raise self.e from None
+         return res.fun
+
torchzero/optim/wrappers/scipy/dual_annealing.py
@@ -0,0 +1,115 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ....utils import TensorList
+ from ..wrapper import WrapperBase
+ from .minimize import _use_jac_hess_hessp
+
+ Closure = Callable[[bool], Any]
+
+
+
+
+ class ScipyDualAnnealing(WrapperBase):
+     def __init__(
+         self,
+         params,
+         lb: float,
+         ub: float,
+         maxiter=1000,
+         minimizer_kwargs=None,
+         initial_temp=5230.0,
+         restart_temp_ratio=2.0e-5,
+         visit=2.62,
+         accept=-5.0,
+         maxfun=1e7,
+         rng=None,
+         no_local_search=False,
+         method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
+                         'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
+                         'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
+                         'trust-krylov'] | str = 'l-bfgs-b',
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
+         use_hessp: bool = True,
+     ):
+         super().__init__(params, dict(lb=lb, ub=ub))
+
+         kwargs = locals().copy()
+         for k in ["self", "params", "lb", "ub", "__class__", "method", "jac", "hess", "use_hessp", "minimizer_kwargs"]:
+             del kwargs[k]
+         self._kwargs = kwargs
+
+         self._minimizer_kwargs = minimizer_kwargs
+         self.method = method
+         self.hess = hess
+
+         self.jac, self.use_jac_autograd, self.use_hess_autograd, self.use_hessp = _use_jac_hess_hessp(method, jac, hess, use_hessp)
+
+     def _jac(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         f, g = self._f_g(x, params, closure)
+         return g
+
+     def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         # dual annealing doesn't support this
+         # if self.use_jac_autograd:
+         #     f, g = self._f_g(x, params, closure)
+         #     if self.method.lower() == 'slsqp': g = g.astype(np.float64) # slsqp requires float64
+         #     return f, g
+
+         return self._f(x, params, closure)
+
+     def _hess(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, H = self._f_g_H(x, params, closure)
+         return H
+
+     def _hessp(self, x: np.ndarray, p: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, Hvp = self._f_g_Hvp(x, p, params, closure)
+         return Hvp
+
+     @torch.no_grad
+     def step(self, closure: Closure):
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+         bounds = self._get_bounds()
+         assert bounds is not None
+
+         # determine hess argument
+         hess = self.hess
+         hessp = None
+         if hess == 'autograd':
+             if self.use_hess_autograd:
+                 if self.use_hessp:
+                     hessp = partial(self._hessp, params=params, closure=closure)
+                     hess = None
+                 else:
+                     hess = partial(self._hess, params=params, closure=closure)
+             # hess = 'autograd' but method doesn't use hess
+             else:
+                 hess = None
+
+         if self.method.lower() in ('tnc', 'slsqp'):
+             x0 = x0.astype(np.float64) # those methods error without this
+
+         minimizer_kwargs = self._minimizer_kwargs.copy() if self._minimizer_kwargs is not None else {}
+         minimizer_kwargs.setdefault("method", self.method)
+         minimizer_kwargs.setdefault("jac", partial(self._jac, params=params, closure=closure))
+         minimizer_kwargs.setdefault("hess", hess)
+         minimizer_kwargs.setdefault("hessp", hessp)
+         minimizer_kwargs.setdefault("bounds", bounds)
+
+         res = scipy.optimize.dual_annealing(
+             partial(self._f, params=params, closure=closure),
+             x0=x0,
+             bounds=bounds,
+             minimizer_kwargs=minimizer_kwargs,
+             **self._kwargs,
+         )
+
+         params.from_vec_(torch.as_tensor(res.x, device=params[0].device, dtype=params[0].dtype))
+         return res.fun