torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/optim/wrappers/nlopt.py
@@ -1,12 +1,11 @@
-from typing import Literal
+from typing import Literal, Any
 from collections.abc import Mapping, Callable
 from functools import partial
 import numpy as np
 import torch

 import nlopt
-from ...core import TensorListOptimizer, _ClosureType
-from ...tensorlist import TensorList
+from ...utils import Optimizer, TensorList

 _ALGOS_LITERAL = Literal[
     "GN_DIRECT", # = _nlopt.GN_DIRECT
@@ -56,18 +55,21 @@ _ALGOS_LITERAL = Literal[
 ]

 def _ensure_float(x):
-    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    if isinstance(x, np.ndarray): return x.item()
+    if isinstance(x, torch.Tensor): return float(x.detach().cpu().item())
+    if isinstance(x, np.ndarray): return float(x.item())
     return float(x)

 def _ensure_tensor(x):
-    if isinstance(x, np.ndarray):
-        x.setflags(write=True)
-        return torch.from_numpy(x)
+    try:
+        if isinstance(x, np.ndarray): return torch.as_tensor(x.copy())
+    except SystemError:
+        return None
     return torch.tensor(x, dtype=torch.float32)

 inf = float('inf')
-class NLOptOptimizer(TensorListOptimizer):
+Closure = Callable[[bool], Any]
+
+class NLOptOptimizer(Optimizer):
     """Use nlopt as pytorch optimizer, with gradient supplied by pytorch autograd.
     Note that this performs full minimization on each step,
     so usually you would want to perform a single step, although performing multiple steps will refine the
@@ -119,8 +121,12 @@ class NLOptOptimizer(TensorListOptimizer):

         self._last_loss = None

-    def _f(self, x: np.ndarray, grad: np.ndarray, closure: _ClosureType, params: TensorList):
-        params.from_vec_(_ensure_tensor(x).to(params[0], copy=False))
+    def _f(self, x: np.ndarray, grad: np.ndarray, closure, params: TensorList):
+        t = _ensure_tensor(x)
+        if t is None:
+            if self.opt is not None: self.opt.force_stop()
+            return None
+        params.from_vec_(t.to(params[0], copy=False))
         if grad.size > 0:
             with torch.enable_grad(): loss = closure()
             self._last_loss = _ensure_float(loss)
@@ -131,12 +137,11 @@
         return self._last_loss

     @torch.no_grad
-    def step(self, closure: _ClosureType): # pylint: disable = signature-differs
-
+    def step(self, closure: Closure): # pylint: disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
         params = self.get_params()

         # make bounds
-        lb, ub = self.get_group_keys('lb', 'ub', cls=list)
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
         lower = []
         upper = []
         for p, l, u in zip(params, lb, ub):
@@ -145,9 +150,10 @@
             lower.extend([l] * p.numel())
             upper.extend([u] * p.numel())

-        x0 = params.to_vec().detach().cpu().numpy()
+        x0 = params.to_vec().detach().cpu().numpy().astype(np.float64)

         self.opt = nlopt.opt(self.algorithm, x0.size)
+        self.opt.set_exceptions_enabled(False) # required
         self.opt.set_min_objective(partial(self._f, closure = closure, params = params))
         self.opt.set_lower_bounds(lower)
         self.opt.set_upper_bounds(upper)
@@ -160,6 +166,15 @@
         if self.xtol_abs is not None: self.opt.set_xtol_abs(self.xtol_abs)
         if self.maxtime is not None: self.opt.set_maxtime(self.maxtime)

-        x = self.opt.optimize(x0)
+        self._last_loss = None
+        x = None
+        try:
+            x = self.opt.optimize(x0)
+        except SystemError:
+            pass
+        except Exception as e:
+            raise e from None
+
+        if self._last_loss is None or x is None: return closure(False)
         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
         return self._last_loss
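
The updated wrapper treats the closure the same way the rest of torchzero does: `closure()` evaluates the loss and populates gradients, `closure(False)` evaluates the loss only, which is why `Closure` is typed as `Callable[[bool], Any]`. A minimal usage sketch follows; the constructor keywords (`'LD_LBFGS'`, `lb`, `ub`, `maxtime`) are assumptions based on the algorithm literals and attributes used in `step` above, not a verbatim API reference.

import torch
import torch.nn.functional as F
from torchzero.optim.wrappers.nlopt import NLOptOptimizer  # module path from the file list above

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

# hypothetical constructor call; a single step() runs a full nlopt minimization
opt = NLOptOptimizer(model.parameters(), 'LD_LBFGS', lb=-10, ub=10, maxtime=5)

def closure(backward: bool = True):
    loss = F.mse_loss(model(X), y)
    if backward:          # only requested when the chosen algorithm needs gradients
        opt.zero_grad()
        loss.backward()
    return loss

final_loss = opt.step(closure)

Because each `step` is a complete minimization run, the usual training-loop pattern of many small steps does not apply here.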
torchzero/optim/wrappers/scipy.py
@@ -1,20 +1,15 @@
-from typing import Literal, Any
 from collections import abc
+from collections.abc import Callable
 from functools import partial
+from typing import Any, Literal

 import numpy as np
-import torch
-
 import scipy.optimize
+import torch

-from ...core import _ClosureType, TensorListOptimizer
-from ...utils.derivatives import jacobian, jacobian_list_to_vec, hessian, hessian_list_to_mat, jacobian_and_hessian
-from ...modules import WrapClosure
-from ...modules.experimental.subspace import Projection, Proj2Masks, ProjGrad, ProjNormalize, Subspace
-from ...modules.second_order.newton import regularize_hessian_
-from ...tensorlist import TensorList
-from ..modular import Modular
-
+from ...utils import Optimizer, TensorList
+from ...utils.derivatives import jacobian_and_hessian_mat_wrt, jacobian_wrt
+from ...modules.second_order.newton import tikhonov_

 def _ensure_float(x):
     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
@@ -26,7 +21,17 @@ def _ensure_numpy(x):
     if isinstance(x, np.ndarray): return x
     return np.array(x)

-class ScipyMinimize(TensorListOptimizer):
+def matrix_clamp(H: torch.Tensor, reg: float):
+    try:
+        eigvals, eigvecs = torch.linalg.eigh(H) # pylint:disable=not-callable
+        eigvals.clamp_(min=reg)
+        return eigvecs @ torch.diag(eigvals) @ eigvecs.mH
+    except Exception:
+        return H
+
+Closure = Callable[[bool], Any]
+
+class ScipyMinimize(Optimizer):
     """Use scipy.minimize.optimize as pytorch optimizer. Note that this performs full minimization on each step,
     so usually you would want to perform a single step, although performing multiple steps will refine the
     solution.
@@ -71,7 +76,8 @@ class ScipyMinimize(TensorListOptimizer):
         options = None,
         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
-        tikhonov: float | Literal['eig'] = 0,
+        tikhonov: float | None = 0,
+        min_eigval: float | None = None,
     ):
         defaults = dict(lb=lb, ub=ub)
         super().__init__(params, defaults)
@@ -79,11 +85,12 @@ class ScipyMinimize(TensorListOptimizer):
         self.constraints = constraints
         self.tol = tol
         self.callback = callback
+        self.min_eigval = min_eigval
         self.options = options

         self.jac = jac
         self.hess = hess
-        self.tikhonov: float | Literal['eig'] = tikhonov
+        self.tikhonov: float | None = tikhonov

         self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
             'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
@@ -93,21 +100,22 @@ class ScipyMinimize(TensorListOptimizer):
             'newton-cg', 'dogleg', 'trust-ncg', 'trust-krylov', 'trust-exact'
         ]

+        # jac in scipy is '2-point', '3-point', 'cs', True or None.
         if self.jac == 'autograd':
             if self.use_jac_autograd: self.jac = True
             else: self.jac = None


-    def _hess(self, x: np.ndarray, params: TensorList, closure: _ClosureType): # type:ignore
+    def _hess(self, x: np.ndarray, params: TensorList, closure):
         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
         with torch.enable_grad():
             value = closure(False)
-            H = hessian([value], wrt = params) # type:ignore
-        Hmat = hessian_list_to_mat(H)
-        regularize_hessian_(Hmat, self.tikhonov)
-        return Hmat.detach().cpu().numpy()
+            _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
+        if self.tikhonov is not None: H = tikhonov_(H, self.tikhonov)
+        if self.min_eigval is not None: H = matrix_clamp(H, self.min_eigval)
+        return H.detach().cpu().numpy()

-    def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
         # set params to x
         params.from_vec_(torch.from_numpy(x).to(params[0], copy=False))

@@ -118,7 +126,7 @@ class ScipyMinimize(TensorListOptimizer):
         return _ensure_float(closure(False))

     @torch.no_grad
-    def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
+    def step(self, closure: Closure):# pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
         params = self.get_params()

         # determine hess argument
@@ -130,7 +138,7 @@ class ScipyMinimize(TensorListOptimizer):
         x0 = params.to_vec().detach().cpu().numpy()

         # make bounds
-        lb, ub = self.get_group_keys('lb', 'ub', cls=list)
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
         bounds = []
         for p, l, u in zip(params, lb, ub):
             bounds.extend([(l, u)] * p.numel())
@@ -156,8 +164,8 @@



-class ScipyRoot(TensorListOptimizer):
-    """Find a root of a vector function (UNTESTED!).
+class ScipyRootOptimization(Optimizer):
+    """Optimization via using scipy.root on gradients, mainly for experimenting!

     Args:
         params: iterable of parameters to optimize or dicts defining parameter groups.
@@ -196,94 +204,11 @@ class ScipyRoot(TensorListOptimizer):
         self.jac = jac
         if self.jac == 'autograd': self.jac = True

-    def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
-        # set params to x
-        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-
-        # return value and maybe gradients
-        if self.jac:
-            with torch.enable_grad():
-                value = closure(False)
-                if not isinstance(value, torch.Tensor):
-                    raise TypeError(f"Autograd jacobian requires closure to return torch.Tensor, got {type(value)}")
-                jac = jacobian_list_to_vec(jacobian([value], wrt=params))
-            return _ensure_numpy(value), jac.detach().cpu().numpy()
-        return _ensure_numpy(closure(False))
-
-    @torch.no_grad
-    def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
-        params = self.get_params()
-
-        x0 = params.to_vec().detach().cpu().numpy()
-
-        res = scipy.optimize.root(
-            partial(self._objective, params = params, closure = closure),
-            x0 = x0,
-            method=self.method,
-            tol=self.tol,
-            callback=self.callback,
-            options=self.options,
-            jac = self.jac,
-        )
-
-        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-        return res.fun
-
-
-class ScipyRootOptimization(TensorListOptimizer):
-    """Optimization via finding roots of the gradient with `scipy.optimize.root` (for experiments, won't work well on most problems).
-
-    Args:
-        params: iterable of parameters to optimize or dicts defining parameter groups.
-        method (str, optional): one of methods from https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root.html#scipy.optimize.root. Defaults to 'hybr'.
-        tol (float | None, optional): tolerance. Defaults to None.
-        callback (_type_, optional): callback. Defaults to None.
-        options (_type_, optional): options for optimizer. Defaults to None.
-        jac (Literal['2, optional): jacobian calculation method. Defaults to 'autograd'.
-        tikhonov (float | Literal['eig'], optional): tikhonov regularization (only for 'hybr' and 'lm'). Defaults to 0.
-        add_loss (float, optional): adds loss value to jacobian multiplied by this to try to avoid finding maxima. Defaults to 0.
-        mul_loss (float, optional): multiplies jacobian by loss value multiplied by this to try to avoid finding maxima. Defaults to 0.
-    """
-    def __init__(
-        self,
-        params,
-        method: Literal[
-            "hybr",
-            "lm",
-            "broyden1",
-            "broyden2",
-            "anderson",
-            "linearmixing",
-            "diagbroyden",
-            "excitingmixing",
-            "krylov",
-            "df-sane",
-        ] = 'hybr',
-        tol: float | None = None,
-        callback = None,
-        options = None,
-        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
-        tikhonov: float | Literal['eig'] = 0,
-        add_loss: float = 0,
-        mul_loss: float = 0,
-    ):
-        super().__init__(params, {})
-        self.method = method
-        self.tol = tol
-        self.callback = callback
-        self.options = options
-        self.value = None
-        self.tikhonov: float | Literal['eig'] = tikhonov
-        self.add_loss = add_loss
-        self.mul_loss = mul_loss
-
-        self.jac = jac == 'autograd'
-
         # those don't require jacobian
         if self.method.lower() in ('broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden', 'excitingmixing', 'krylov', 'df-sane'):
             self.jac = None

-    def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
         # set params to x
         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))

@@ -293,23 +218,16 @@ class ScipyRootOptimization(TensorListOptimizer):
            self.value = closure(False)
            if not isinstance(self.value, torch.Tensor):
                raise TypeError(f"Autograd jacobian requires closure to return torch.Tensor, got {type(self.value)}")
-           jac_list, hess_list = jacobian_and_hessian([self.value], wrt=params)
-           jac = jacobian_list_to_vec(jac_list)
-           hess = hessian_list_to_mat(hess_list)
-           regularize_hessian_(hess, self.tikhonov)
-           if self.mul_loss != 0: jac *= self.value * self.mul_loss
-           if self.add_loss != 0: jac += self.value * self.add_loss
-           return jac.detach().cpu().numpy(), hess.detach().cpu().numpy()
+           g, H = jacobian_and_hessian_mat_wrt([self.value], wrt=params)
+           return g.detach().cpu().numpy(), H.detach().cpu().numpy()

        # return the gradients
        with torch.enable_grad(): self.value = closure()
        jac = params.ensure_grad_().grad.to_vec()
-       if self.mul_loss != 0: jac *= self.value * self.mul_loss
-       if self.add_loss != 0: jac += self.value * self.add_loss
        return jac.detach().cpu().numpy()

     @torch.no_grad
-    def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
+    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
         params = self.get_params()

         x0 = params.to_vec().detach().cpu().numpy()
@@ -325,9 +243,11 @@ class ScipyRootOptimization(TensorListOptimizer):
         )

         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-        return self.value
+        return res.fun
+

-class ScipyDE(TensorListOptimizer):
+
+class ScipyDE(Optimizer):
     """Use scipy.minimize.differential_evolution as pytorch optimizer. Note that this performs full minimization on each step,
     so usually you would want to perform a single step. This also requires bounds to be specified.

@@ -374,12 +294,12 @@ class ScipyDE(TensorListOptimizer):
         self._kwargs = kwargs
         self._lb, self._ub = bounds

-    def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
         return _ensure_float(closure(False))

     @torch.no_grad
-    def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
+    def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
         params = self.get_params()

         x0 = params.to_vec().detach().cpu().numpy()
@@ -396,44 +316,47 @@
         return res.fun


-class ScipyMinimizeSubspace(Modular):
-    """for experiments and won't work well on most problems.

-    explanation - optimizes in a small subspace using scipy.optimize.minimize, but doesnt seem to work well"""
+class ScipyDualAnnealing(Optimizer):
     def __init__(
         self,
         params,
-        projections: Projection | abc.Iterable[Projection] = (
-            Proj2Masks(5),
-            ProjNormalize(
-                ProjGrad(),
-            )
-        ),
-        method=None,
-        lb = None,
-        ub = None,
-        constraints=(),
-        tol=None,
-        callback=None,
-        options=None,
-        jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
-        hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = '2-point',
+        bounds: tuple[float, float],
+        maxiter=1000,
+        minimizer_kwargs=None,
+        initial_temp=5230.0,
+        restart_temp_ratio=2.0e-5,
+        visit=2.62,
+        accept=-5.0,
+        maxfun=1e7,
+        rng=None,
+        no_local_search=False,
     ):
+        super().__init__(params, {})

-        scopt = WrapClosure(
-            ScipyMinimize,
-            method = method,
-            lb = lb,
-            ub = ub,
-            constraints = constraints,
-            tol = tol,
-            callback = callback,
-            options = options,
-            jac = jac,
-            hess = hess
-        )
-        modules = [
-            Subspace(scopt, projections),
-        ]
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['bounds'], kwargs['__class__']
+        self._kwargs = kwargs
+        self._lb, self._ub = bounds
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):# pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+        params = self.get_params()
+
+        x0 = params.to_vec().detach().cpu().numpy()
+        bounds = [(self._lb, self._ub)] * len(x0)
+
+        res = scipy.optimize.dual_annealing(
+            partial(self._objective, params = params, closure = closure),
+            x0 = x0,
+            bounds=bounds,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.fun

-        super().__init__(params, modules)
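
All of these scipy wrappers share the same `closure(backward: bool)` protocol; the gradient-free ones (`ScipyDE` and the new `ScipyDualAnnealing`) only ever call `closure(False)`. A rough usage sketch of `ScipyDualAnnealing`, assuming the module path from the file list and the `bounds`/`maxiter` arguments shown in its `__init__` above:

import torch
from torchzero.optim.wrappers.scipy import ScipyDualAnnealing

x = torch.nn.Parameter(torch.randn(10))
opt = ScipyDualAnnealing([x], bounds=(-5.0, 5.0), maxiter=200)

def rosenbrock(v):
    return ((1 - v[:-1])**2 + 100 * (v[1:] - v[:-1]**2)**2).sum()

def closure(backward: bool = True):
    loss = rosenbrock(x)
    if backward:   # never requested by this gradient-free wrapper; kept for interface compatibility
        opt.zero_grad()
        loss.backward()
    return loss

best_loss = opt.step(closure)  # one step runs a full dual_annealing search within the bounds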
torchzero/utils/__init__.py
@@ -0,0 +1,27 @@
+from . import tensorlist as tl
+from .compile import _optional_compiler, benchmark_compile_cpu, benchmark_compile_cuda, set_compilation, enable_compilation
+from .numberlist import NumberList
+from .optimizer import (
+    Init,
+    ListLike,
+    Optimizer,
+    ParamFilter,
+    get_group_vals,
+    get_params,
+    get_state_vals,
+    grad_at_params,
+    grad_vec_at_params,
+    loss_at_params,
+    loss_grad_at_params,
+    loss_grad_vec_at_params,
+)
+from .params import (
+    Params,
+    _add_defaults_to_param_groups_,
+    _add_updates_grads_to_param_groups_,
+    _copy_param_groups,
+    _make_param_groups,
+)
+from .python_tools import flatten, generic_eq, reduce_dim
+from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like
+from .torch_tools import tofloat, tolist, tonumpy, totensor, vec_to_tensors, vec_to_tensors_, set_storage_
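
The wrappers above lean on the flatten/unflatten round trip that `TensorList` (now re-exported from `torchzero.utils`) provides through `to_vec` and `from_vec_`. A small sketch of that pattern, assuming the methods behave as they are used in the diffs above and that a `TensorList` can be built from a plain list of tensors:

import torch
from torchzero.utils import TensorList

params = TensorList([torch.randn(3, 4), torch.randn(5)])

vec = params.to_vec()                 # flatten every tensor into one 1-D vector (17 elements here)
assert vec.numel() == 3 * 4 + 5

vec = torch.zeros_like(vec)           # e.g. hand the vector to an external optimizer...
params.from_vec_(vec)                 # ...then write the result back into the original tensors in place
assert all((p == 0).all() for p in params)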