torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0

torchzero/optim/wrappers/scipy/experimental.py
@@ -0,0 +1,141 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ....utils import TensorList
+ from ..wrapper import WrapperBase
+
+ Closure = Callable[[bool], Any]
+
+
+ class ScipyRootOptimization(WrapperBase):
+     """Optimization by running scipy.optimize.root on the gradients, mainly for experimenting!
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         method (str, optional): root-finding solver passed to scipy.optimize.root. Defaults to 'hybr'.
+         tol (float | None, optional): tolerance for termination. Defaults to None.
+         callback (Callable | None, optional): optional callable called on each iteration. Defaults to None.
+         options (dict | None, optional): dictionary of solver options. Defaults to None.
+         jac (str, optional): method for computing the jacobian of the gradients; 'autograd' uses pytorch autograd. Defaults to 'autograd'.
+     """
+     def __init__(
+         self,
+         params,
+         method: Literal[
+             "hybr",
+             "lm",
+             "broyden1",
+             "broyden2",
+             "anderson",
+             "linearmixing",
+             "diagbroyden",
+             "excitingmixing",
+             "krylov",
+             "df-sane",
+         ] = 'hybr',
+         tol: float | None = None,
+         callback = None,
+         options = None,
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+     ):
+         super().__init__(params, {})
+         self.method = method
+         self.tol = tol
+         self.callback = callback
+         self.options = options
+
+         self.jac = jac
+         if self.jac == 'autograd': self.jac = True
+
+         # those don't require jacobian
+         if self.method.lower() in ('broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden', 'excitingmixing', 'krylov', 'df-sane'):
+             self.jac = None
+
+     def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         if self.jac:
+             f, g, H = self._f_g_H(x, params, closure)
+             return g, H
+
+         f, g = self._f_g(x, params, closure)
+         return g
+
+     @torch.no_grad
+     def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+
+         res = scipy.optimize.root(
+             partial(self._objective, params = params, closure = closure),
+             x0 = x0,
+             method=self.method,
+             tol=self.tol,
+             callback=self.callback,
+             options=self.options,
+             jac = self.jac,
+         )
+
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
+         return res.fun
+
+
+ class ScipyLeastSquaresOptimization(WrapperBase):
+     """Optimization by running scipy.optimize.least_squares on the gradients, mainly for experimenting!
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         method (str, optional): solver passed to scipy.optimize.least_squares. Defaults to 'trf'.
+         jac (str, optional): method for computing the jacobian of the gradients; 'autograd' uses pytorch autograd. Defaults to 'autograd'.
+         other args: passed through to scipy.optimize.least_squares, see
+             https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.least_squares.html.
+     """
+     def __init__(
+         self,
+         params,
+         method='trf',
+         jac='autograd',
+         bounds=(-np.inf, np.inf),
+         ftol=1e-8, xtol=1e-8, gtol=1e-8, x_scale=1.0, loss='linear',
+         f_scale=1.0, diff_step=None, tr_solver=None, tr_options=None,
+         jac_sparsity=None, max_nfev=None, verbose=0
+     ):
+         super().__init__(params, {})
+         kwargs = locals().copy()
+         del kwargs['self'], kwargs['params'], kwargs['__class__'], kwargs['jac']
+         self._kwargs = kwargs
+
+         self.jac = jac
+
+
+     def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         f, g = self._f_g(x, params, closure)
+         return g
+
+     def _hess(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, H = self._f_g_H(x, params, closure)
+         return H
+
+     @torch.no_grad
+     def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+
+         if self.jac == 'autograd': jac = partial(self._hess, params = params, closure = closure)
+         else: jac = self.jac
+
+         res = scipy.optimize.least_squares(
+             partial(self._objective, params = params, closure = closure),
+             x0 = x0,
+             jac=jac, # type:ignore
+             **self._kwargs
+         )
+
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
+         return res.fun
+
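
A minimal usage sketch for the root-finding wrapper above. The import path is assumed from the new file layout (torchzero/optim/wrappers/scipy/experimental.py); the model, data, and closure are illustrative only:

    import torch
    from torchzero.optim.wrappers.scipy.experimental import ScipyRootOptimization

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)
    opt = ScipyRootOptimization(model.parameters(), method='hybr')

    def closure(backward=True):
        # the wrapper calls closure() to populate gradients, or closure(False) for the value only
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    final_loss = opt.step(closure)  # one step runs scipy.optimize.root on the gradient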

torchzero/optim/wrappers/scipy/minimize.py
@@ -0,0 +1,151 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ....utils import TensorList
+ from ..wrapper import WrapperBase
+
+ Closure = Callable[[bool], Any]
+
+
+ def _use_jac_hess_hessp(method, jac, hess, use_hessp):
+     # those methods can't use hessp
+     if (method is None) or (method.lower() not in ("newton-cg", "trust-ncg", "trust-krylov", "trust-constr")):
+         use_hessp = False
+
+     # those use gradients
+     use_jac_autograd = (jac.lower() == 'autograd') and ((method is None) or (method.lower() in [
+         'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
+         'trust-ncg', 'trust-krylov', 'trust-exact', 'trust-constr',
+     ]))
+
+     # those use hessian / some of them can use hessp instead
+     use_hess_autograd = (isinstance(hess, str)) and (hess.lower() == 'autograd') and (method is not None) and (method.lower() in [
+         'newton-cg', 'dogleg', 'trust-ncg', 'trust-krylov', 'trust-exact'
+     ])
+
+     # jac in scipy is '2-point', '3-point', 'cs', True or None.
+     if jac == 'autograd':
+         if use_jac_autograd: jac = True
+         else: jac = None
+
+     return jac, use_jac_autograd, use_hess_autograd, use_hessp
+
+ class ScipyMinimize(WrapperBase):
+     """Use scipy.optimize.minimize as a pytorch optimizer. Note that this performs a full minimization on each step,
+     so usually you would want to perform a single step, although performing multiple steps will refine the
+     solution.
+
+     Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
+     for a detailed description of the args.
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         method (str | None, optional): type of solver.
+             If None, scipy will select one of BFGS, L-BFGS-B, SLSQP,
+             depending on whether or not the problem has constraints or bounds.
+             Defaults to None.
+         lb (optional): lower bounds on variables. Defaults to None.
+         ub (optional): upper bounds on variables. Defaults to None.
+         constraints (tuple, optional): constraints definition. Defaults to ().
+         tol (float | None, optional): tolerance for termination. Defaults to None.
+         callback (Callable | None, optional): a callable called after each iteration. Defaults to None.
+         options (dict | None, optional): a dictionary of solver options. Defaults to None.
+         jac (str, optional): method for computing the gradient vector.
+             Only for CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
+             In addition to the scipy options, this supports 'autograd', which uses pytorch autograd.
+             This setting is ignored for methods that don't require the gradient. Defaults to 'autograd'.
+         hess (str, optional):
+             method for computing the Hessian matrix.
+             Only for Newton-CG, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
+             This setting is ignored for methods that don't require the hessian. Defaults to 'autograd'.
+         use_hessp (bool, optional):
+             whether to pass a hessian-vector product instead of the full hessian for methods that support it. Defaults to True.
+     """
+     def __init__(
+         self,
+         params,
+         method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
+                         'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
+                         'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
+                         'trust-krylov'] | str | None = None,
+         lb = None,
+         ub = None,
+         constraints = (),
+         tol: float | None = None,
+         callback = None,
+         options = None,
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
+         use_hessp: bool = True,
+     ):
+         defaults = dict(lb=lb, ub=ub)
+         super().__init__(params, defaults)
+         self.method = method
+         self.constraints = constraints
+         self.tol = tol
+         self.callback = callback
+         self.options = options
+         self.hess = hess
+
+         self.jac, self.use_jac_autograd, self.use_hess_autograd, self.use_hessp = _use_jac_hess_hessp(method, jac, hess, use_hessp)
+
+     def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         if self.use_jac_autograd:
+             f, g = self._f_g(x, params, closure)
+             if self.method is not None and self.method.lower() == 'slsqp': g = g.astype(np.float64) # slsqp requires float64
+             return f, g
+
+         return self._f(x, params, closure)
+
+     def _hess(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, H = self._f_g_H(x, params, closure)
+         return H
+
+     def _hessp(self, x: np.ndarray, p: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, Hvp = self._f_g_Hvp(x, p, params, closure)
+         return Hvp
+
+     @torch.no_grad
+     def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+         bounds = self._get_bounds()
+
+         # determine hess argument
+         hess = self.hess
+         hessp = None
+         if hess == 'autograd':
+             if self.use_hess_autograd:
+                 if self.use_hessp:
+                     hessp = partial(self._hessp, params=params, closure=closure)
+                     hess = None
+                 else:
+                     hess = partial(self._hess, params=params, closure=closure)
+             # hess = 'autograd' but method doesn't use hess
+             else:
+                 hess = None
+
+
+         if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
+             x0 = x0.astype(np.float64) # those methods error without this
+
+         res = scipy.optimize.minimize(
+             partial(self._objective, params = params, closure = closure),
+             x0 = x0,
+             method=self.method,
+             bounds=bounds,
+             constraints=self.constraints,
+             tol=self.tol,
+             callback=self.callback,
+             options=self.options,
+             jac = self.jac,
+             hess = hess,
+             hessp = hessp
+         )
+
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
+         return res.fun
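
As the docstring above notes, ScipyMinimize runs a full scipy.optimize.minimize solve per step, so a single step call is usually enough. A minimal sketch, assuming the import path from the new file layout (torchzero/optim/wrappers/scipy/minimize.py); the objective is illustrative:

    import torch
    from torchzero.optim.wrappers.scipy.minimize import ScipyMinimize

    w = torch.zeros(10, requires_grad=True)
    opt = ScipyMinimize([w], method='trust-ncg')  # uses autograd hessian-vector products by default

    def closure(backward=True):
        loss = (w ** 2).sum() + w.sin().sum()
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    final_value = opt.step(closure)  # one call minimizes to convergence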

torchzero/optim/wrappers/scipy/sgho.py
@@ -0,0 +1,111 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ....utils import TensorList
+ from ..wrapper import WrapperBase
+ from .minimize import _use_jac_hess_hessp
+
+ Closure = Callable[[bool], Any]
+
+
+ class ScipySHGO(WrapperBase):
+     def __init__(
+         self,
+         params,
+         lb: float,
+         ub: float,
+         constraints = None,
+         n: int = 100,
+         iters: int = 1,
+         callback = None,
+         options: dict | None = None,
+         sampling_method: str = 'simplicial',
+         minimizer_kwargs: dict | None = None,
+         method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
+                         'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
+                         'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
+                         'trust-krylov'] | str = 'l-bfgs-b',
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
+         use_hessp: bool = True,
+     ):
+         super().__init__(params, dict(lb=lb, ub=ub))
+
+         kwargs = locals().copy()
+         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__'], kwargs["options"]
+         del kwargs["method"], kwargs["jac"], kwargs["hess"], kwargs["use_hessp"], kwargs["minimizer_kwargs"]
+         self._kwargs = kwargs
+         self.minimizer_kwargs = minimizer_kwargs
+         self.options = options
+         self.method = method
+         self.hess = hess
+
+         self.jac, self.use_jac_autograd, self.use_hess_autograd, self.use_hessp = _use_jac_hess_hessp(method, jac, hess, use_hessp)
+
+
+     def _objective(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         if self.use_jac_autograd:
+             f, g = self._f_g(x, params, closure)
+             if self.method.lower() == 'slsqp': g = g.astype(np.float64) # slsqp requires float64
+             return f, g
+
+         return self._f(x, params, closure)
+
+     def _hess(self, x: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, H = self._f_g_H(x, params, closure)
+         return H
+
+     def _hessp(self, x: np.ndarray, p: np.ndarray, params: list[torch.Tensor], closure):
+         f, g, Hvp = self._f_g_Hvp(x, p, params, closure)
+         return Hvp
+
+     @torch.no_grad
+     def step(self, closure: Closure):
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+         bounds = self._get_bounds()
+         assert bounds is not None
+
+         # determine hess argument
+         hess = self.hess
+         hessp = None
+         if hess == 'autograd':
+             if self.use_hess_autograd:
+                 if self.use_hessp:
+                     hessp = partial(self._hessp, params=params, closure=closure)
+                     hess = None
+                 else:
+                     hess = partial(self._hess, params=params, closure=closure)
+             # hess = 'autograd' but method doesn't use hess
+             else:
+                 hess = None
+
+
+         if self.method.lower() in ('tnc', 'slsqp'):
+             x0 = x0.astype(np.float64) # those methods error without this
+
+         minimizer_kwargs = self.minimizer_kwargs.copy() if self.minimizer_kwargs is not None else {}
+         minimizer_kwargs.setdefault("method", self.method)
+
+         options = self.options.copy() if self.options is not None else {}
+         minimizer_kwargs.setdefault("jac", self.jac)
+         minimizer_kwargs.setdefault("hess", hess)
+         minimizer_kwargs.setdefault("hessp", hessp)
+         minimizer_kwargs.setdefault("bounds", bounds)
+
+         res = scipy.optimize.shgo(
+             partial(self._objective, params=params, closure=closure),
+             bounds=bounds,
+             minimizer_kwargs=minimizer_kwargs,
+             options=options,
+             **self._kwargs
+         )
+
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
+         return res.fun
+
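
ScipySHGO requires box bounds (lb/ub) on every parameter, since shgo is a global sampling-based optimizer; step() asserts that bounds exist. A minimal sketch, assuming the import path from the new file layout (torchzero/optim/wrappers/scipy/sgho.py) and using finite-difference gradients in the local minimizer to keep the sketch simple:

    import torch
    from torchzero.optim.wrappers.scipy.sgho import ScipySHGO

    w = torch.zeros(2, requires_grad=True)
    # lb/ub are stored per parameter group and expanded to per-element bounds for scipy
    opt = ScipySHGO([w], lb=-5.0, ub=5.0, n=64, iters=2, method='l-bfgs-b', jac='2-point')

    def closure(backward=True):
        loss = (w[0] - 1) ** 2 + 100 * (w[1] - w[0] ** 2) ** 2  # Rosenbrock
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    best_value = opt.step(closure)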

torchzero/optim/wrappers/wrapper.py
@@ -0,0 +1,121 @@
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ import numpy as np
+ import torch
+
+ from ...utils import TensorList, tonumpy
+ from ...utils.derivatives import (
+     flatten_jacobian,
+     jacobian_and_hessian_mat_wrt,
+     jacobian_wrt,
+ )
+
+
+ class WrapperBase(torch.optim.Optimizer):
+     def __init__(self, params, defaults):
+         super().__init__(params, defaults)
+
+     @torch.no_grad
+     def _f(self, x: np.ndarray, params: list[torch.Tensor], closure) -> float:
+         # set params to x
+         params = TensorList(params)
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+         return float(closure(False))
+
+     @torch.no_grad
+     def _fs(self, x: np.ndarray, params: list[torch.Tensor], closure) -> np.ndarray:
+         # set params to x
+         params = TensorList(params)
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+         return tonumpy(closure(False)).reshape(-1)
+
+
+     @torch.no_grad
+     def _f_g(self, x: np.ndarray, params: list[torch.Tensor], closure) -> tuple[float, np.ndarray]:
+         # set params to x
+         params = TensorList(params)
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+         # compute value and derivatives
+         with torch.enable_grad():
+             value = closure()
+             g = params.grad.fill_none(reference=params).to_vec()
+         return float(value), g.numpy(force=True)
+
+     @torch.no_grad
+     def _f_g_H(self, x: np.ndarray, params: list[torch.Tensor], closure) -> tuple[float, np.ndarray, np.ndarray]:
+         # set params to x
+         params = TensorList(params)
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+         # compute value and derivatives
+         with torch.enable_grad():
+             value = closure(False)
+             g, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
+         return float(value), g.numpy(force=True), H.numpy(force=True)
+
+     @torch.no_grad
+     def _f_g_Hvp(self, x: np.ndarray, v: np.ndarray, params: list[torch.Tensor], closure) -> tuple[float, np.ndarray, np.ndarray]:
+         # set params to x
+         params = TensorList(params)
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+         # compute value and derivatives
+         with torch.enable_grad():
+             value = closure(False)
+             grad = torch.autograd.grad(value, params, create_graph=True, allow_unused=True, materialize_grads=True)
+             flat_grad = torch.cat([i.reshape(-1) for i in grad])
+             Hvp = torch.autograd.grad(flat_grad, params, torch.as_tensor(v, device=flat_grad.device, dtype=flat_grad.dtype))[0]
+
+         return float(value), flat_grad.numpy(force=True), Hvp.numpy(force=True)
+
+     def _get_params(self) -> list[torch.Tensor]:
+         return [p for g in self.param_groups for p in g["params"]]
+
+     def _get_per_parameter_lb_ub(self):
+         # get per-parameter lb and ub
+         lb = []
+         ub = []
+         for group in self.param_groups:
+             lb.extend([group["lb"]] * len(group["params"]))
+             ub.extend([group["ub"]] * len(group["params"]))
+
+         return lb, ub
+
+     def _get_bounds(self):
+
+         # get per-parameter lb and ub
+         lb, ub = self._get_per_parameter_lb_ub()
+         if all(i is None for i in lb) and all(i is None for i in ub): return None
+
+         params = self._get_params()
+         bounds = []
+         for p, l, u in zip(params, lb, ub):
+             bounds.extend([(l, u)] * p.numel())
+
+         return bounds
+
+     def _get_lb_ub(self, ld: dict | None = None, ud: dict | None = None):
+         if ld is None: ld = {}
+         if ud is None: ud = {}
+
+         # get per-parameter lb and ub
+         lb, ub = self._get_per_parameter_lb_ub()
+
+         params = self._get_params()
+         lb_list = []
+         ub_list = []
+         for p, l, u in zip(params, lb, ub):
+             if l in ld: l = ld[l]
+             if u in ud: u = ud[u]
+             lb_list.extend([l] * p.numel())
+             ub_list.extend([u] * p.numel())
+
+         return lb_list, ub_list
+
+     @abstractmethod
+     def step(self, closure) -> Any: # pyright:ignore[reportIncompatibleMethodOverride] # pylint:disable=signature-differs
+         ...

torchzero/utils/__init__.py
@@ -1,33 +1,15 @@
  from . import tensorlist as tl
- from .compile import (
-     _optional_compiler,
-     benchmark_compile_cpu,
-     benchmark_compile_cuda,
-     enable_compilation,
-     set_compilation,
- )
- from .numberlist import NumberList
- from .optimizer import (
-     Init,
-     ListLike,
-     Optimizer,
-     ParamFilter,
-     get_group_vals,
-     get_params,
-     get_state_vals,
-     unpack_states,
- )
- from .params import (
-     Params,
-     _add_defaults_to_param_groups_,
-     _add_updates_grads_to_param_groups_,
-     _copy_param_groups,
-     _make_param_groups,
- )
+
+ from .metrics import evaluate_metric
+ from .numberlist import NumberList, maybe_numberlist
+ from .optimizer import unpack_states
+
+
  from .python_tools import (
      flatten,
      generic_eq,
      generic_ne,
+     generic_is_none,
      reduce_dim,
      safe_dict_update_,
      unpack_dicts,

torchzero/utils/benchmarks/logistic.py
@@ -0,0 +1,122 @@
+ from functools import partial
+ from typing import Any, cast
+
+ import numpy as np
+ import torch
+ import tqdm
+
+
+ def generate_correlated_logistic_data(n_samples=2000, n_features=32, n_correlated_pairs=512, correlation=0.99, seed=0):
+     """Hard logistic regression dataset with correlated features"""
+     generator = np.random.default_rng(seed)
+
+     # ------------------------------------- X ------------------------------------ #
+     X = generator.standard_normal(size=(n_samples, n_features))
+     weights = generator.uniform(-2, 2, n_features)
+
+     used_pairs = []
+     for i in range(n_correlated_pairs):
+         idxs = None
+         while idxs is None or idxs in used_pairs:
+             idxs = tuple(generator.choice(n_features, size=2, replace=False).tolist())
+
+         used_pairs.append(idxs)
+         idx1, idx2 = idxs
+
+         noise = generator.standard_normal(n_samples) * np.sqrt(1 - correlation**2)
+         X[:, idx2] = correlation * X[:, idx1] + noise
+
+         w = generator.integers(1, 51)
+         weights[idx1] = w
+         weights[idx2] = -w
+
+     # ---------------------------------- logits ---------------------------------- #
+     logits = X @ weights
+     probabilities = 1 / (1 + np.exp(-logits))
+     y = generator.binomial(1, probabilities).astype(np.float32)
+
+     X = X - X.mean(0, keepdims=True)
+     X = X / X.std(0, keepdims=True)
+     return X, y
+
+
+ # if __name__ == '__main__':
+ #     X, y = generate_correlated_logistic_data()
+
+ #     plt.figure(figsize=(10, 8))
+ #     sns.heatmap(pl.DataFrame(X).corr(), annot=True, cmap='coolwarm', fmt=".2f")
+ #     plt.show()
+
+
+
+
+ def _tensorlist_equal(t1, t2):
+     return all((a == b).all() for a, b in zip(t1, t2))
+
+ _placeholder = cast(Any, ...)
+
+ def run_logistic_regression(X: torch.Tensor, y: torch.Tensor, opt_fn, max_steps: int, tol: float = 0, l1: float = 0, l2: float = 0, pbar: bool = False, *, _assert_on_evaluated_same_params: bool = False):
+     # ------------------------------- verify inputs ------------------------------ #
+     n_samples, n_features = X.size()
+
+     if y.ndim != 1: raise ValueError(f"y should be 1d, got {y.shape}")
+     if y.size(0) != n_samples: raise ValueError(f"y should have {n_samples} elements, got {y.shape}")
+     if y.device != X.device: raise ValueError(f"X and y should be on same device, got {X.device = }, {y.device = }")
+     device = X.device
+     dtype = X.dtype
+
+     # ---------------------------- model and criterion --------------------------- #
+     n_targets = int(y.amax()) + 1
+     binary = n_targets == 2
+
+     if binary:
+         criterion = torch.nn.functional.binary_cross_entropy_with_logits
+         model = torch.nn.Linear(n_features, 1).to(device=device, dtype=dtype)
+         y = y.to(dtype=dtype)
+     else:
+         model = torch.nn.Linear(n_features, n_targets).to(device=device, dtype=dtype)
+         criterion = torch.nn.functional.cross_entropy
+         y = y.long()
+
+     optimizer = opt_fn(list(model.parameters()))
+
+     # ---------------------------------- closure --------------------------------- #
+     def _l1_penalty():
+         return sum(p.abs().sum() for p in model.parameters())
+     def _l2_penalty():
+         return sum(p.square().sum() for p in model.parameters())
+
+     def closure(backward=True, evaluated_params: list = _placeholder, epoch: int = _placeholder):
+         y_hat = model(X)
+         loss = criterion(y_hat.squeeze(), y)
+
+         if l1 > 0: loss += _l1_penalty() * l1
+         if l2 > 0: loss += _l2_penalty() * l2
+
+         if backward:
+             optimizer.zero_grad()
+             loss.backward()
+
+         # here I also test to make sure the optimizer doesn't evaluate the same parameters twice per step
+         # this is for tests
+         if _assert_on_evaluated_same_params:
+             for p in evaluated_params:
+                 assert not _tensorlist_equal(p, model.parameters()), f"evaluated same parameters on epoch {epoch}"
+
+             evaluated_params.append([p.clone() for p in model.parameters()])
+
+         return loss
+
+     # --------------------------------- optimize --------------------------------- #
+     losses = []
+     epochs = tqdm.trange(max_steps, disable=not pbar)
+     for epoch in epochs:
+         evaluated_params = []
+         loss = float(optimizer.step(partial(closure, evaluated_params=evaluated_params, epoch=epoch)))
+
+         losses.append(loss)
+         epochs.set_postfix_str(f"{loss:.5f}")
+         if loss <= tol:
+             break
+
+     return losses
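
A short example of driving the benchmark above with a closure-based optimizer. The module path comes from the new file layout (torchzero/utils/benchmarks/logistic.py); the optimizer choice and hyperparameters are illustrative:

    import torch
    from torchzero.utils.benchmarks.logistic import (
        generate_correlated_logistic_data,
        run_logistic_regression,
    )

    X_np, y_np = generate_correlated_logistic_data(n_samples=500, n_features=16, n_correlated_pairs=8)
    X = torch.as_tensor(X_np, dtype=torch.float32)
    y = torch.as_tensor(y_np)

    # opt_fn receives the model's parameters and must return an optimizer whose step accepts a closure
    losses = run_logistic_regression(
        X, y,
        opt_fn=lambda params: torch.optim.LBFGS(params, lr=0.5, max_iter=10),
        max_steps=50,
    )
    print(f"final loss: {losses[-1]:.4f}")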