torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0

torchzero/optim/wrappers/directsearch.py
@@ -7,24 +7,13 @@ import numpy as np
  import torch
  from directsearch.ds import DEFAULT_PARAMS
 
- from ...utils import Optimizer, TensorList
-
-
- def _ensure_float(x):
-     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-     if isinstance(x, np.ndarray): return x.item()
-     return float(x)
-
- def _ensure_numpy(x):
-     if isinstance(x, torch.Tensor): return x.detach().cpu()
-     if isinstance(x, np.ndarray): return x
-     return np.array(x)
-
+ from ...utils import TensorList
+ from .wrapper import WrapperBase
 
  Closure = Callable[[bool], Any]
 
 
- class DirectSearch(Optimizer):
+ class DirectSearch(WrapperBase):
      """Use directsearch as pytorch optimizer.
 
      Note that this performs full minimization on each step,
@@ -96,28 +85,23 @@ class DirectSearch(Optimizer):
          del kwargs['self'], kwargs['params'], kwargs['__class__']
          self._kwargs = kwargs
 
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
      @torch.no_grad
      def step(self, closure: Closure):
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
 
          res = directsearch.solve(
-             partial(self._objective, params = params, closure = closure),
+             partial(self._f, params=params, closure=closure),
              x0 = x0,
              **self._kwargs
          )
 
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
          return res.f
 
 
 
- class DirectSearchDS(Optimizer):
+ class DirectSearchDS(WrapperBase):
      def __init__(
          self,
          params,
@@ -139,26 +123,21 @@ class DirectSearchDS(Optimizer):
          del kwargs['self'], kwargs['params'], kwargs['__class__']
          self._kwargs = kwargs
 
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
      @torch.no_grad
      def step(self, closure: Closure):
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
 
          res = directsearch.solve_directsearch(
-             partial(self._objective, params = params, closure = closure),
+             partial(self._f, params = params, closure = closure),
              x0 = x0,
              **self._kwargs
          )
 
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
          return res.f
 
- class DirectSearchProbabilistic(Optimizer):
+ class DirectSearchProbabilistic(WrapperBase):
      def __init__(
          self,
          params,
@@ -179,27 +158,22 @@ class DirectSearchProbabilistic(Optimizer):
          del kwargs['self'], kwargs['params'], kwargs['__class__']
          self._kwargs = kwargs
 
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
      @torch.no_grad
      def step(self, closure: Closure):
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
 
          res = directsearch.solve_probabilistic_directsearch(
-             partial(self._objective, params = params, closure = closure),
+             partial(self._f, params = params, closure = closure),
              x0 = x0,
              **self._kwargs
          )
 
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
          return res.f
 
 
- class DirectSearchSubspace(Optimizer):
+ class DirectSearchSubspace(WrapperBase):
      def __init__(
          self,
          params,
@@ -223,28 +197,23 @@ class DirectSearchSubspace(Optimizer):
          del kwargs['self'], kwargs['params'], kwargs['__class__']
          self._kwargs = kwargs
 
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
      @torch.no_grad
      def step(self, closure: Closure):
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
 
          res = directsearch.solve_subspace_directsearch(
-             partial(self._objective, params = params, closure = closure),
+             partial(self._f, params = params, closure = closure),
              x0 = x0,
              **self._kwargs
          )
 
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
          return res.f
 
 
 
- class DirectSearchSTP(Optimizer):
+ class DirectSearchSTP(WrapperBase):
      def __init__(
          self,
          params,
@@ -260,21 +229,16 @@ class DirectSearchSTP(Optimizer):
          del kwargs['self'], kwargs['params'], kwargs['__class__']
          self._kwargs = kwargs
 
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
-
      @torch.no_grad
      def step(self, closure: Closure):
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
 
          res = directsearch.solve_stp(
-             partial(self._objective, params = params, closure = closure),
+             partial(self._f, params = params, closure = closure),
              x0 = x0,
              **self._kwargs
          )
 
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
          return res.f
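
Each wrapper class above previously carried its own `_objective`, plus module-level `_ensure_float`/`_ensure_numpy` helpers; 0.4.0 deletes all of them in favour of a shared `_f` inherited from the new `WrapperBase` (torchzero/optim/wrappers/wrapper.py, whose body is not shown in this diff). A minimal sketch of what that shared objective adapter has to do, written with a plain list of tensors so it stands alone; the name `flat_objective` is illustrative, not torchzero API:

```python
# Sketch of the pattern the removed _objective methods implemented, assuming a
# plain list of tensors instead of torchzero's TensorList. Not the actual
# WrapperBase._f implementation.
import numpy as np
import torch

def flat_objective(x: np.ndarray, params: list, closure) -> float:
    vec = torch.as_tensor(x, device=params[0].device, dtype=params[0].dtype)
    offset = 0
    with torch.no_grad():  # step() in these wrappers runs under @torch.no_grad
        for p in params:
            n = p.numel()
            p.copy_(vec[offset:offset + n].view_as(p))  # write the slice back into the parameter
            offset += n
    loss = closure(False)  # False: derivative-free solvers never need a backward pass
    return float(loss)     # the backend solver expects a plain Python float
```

Dropping the per-class copies is what lets every `step` above shrink to `params = TensorList(self._get_params())` followed by a single call into the backend solver.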
torchzero/optim/wrappers/fcmaes.py
@@ -9,20 +9,15 @@ import fcmaes
  import fcmaes.optimizer
  import fcmaes.retry
 
- from ...utils import Optimizer, TensorList
+ from ...utils import TensorList
+ from .wrapper import WrapperBase
 
  Closure = Callable[[bool], Any]
 
-
- def _ensure_float(x) -> float:
-     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-     if isinstance(x, np.ndarray): return float(x.item())
-     return float(x)
-
  def silence_fcmaes():
      fcmaes.retry.logger.disable('fcmaes')
 
- class FcmaesWrapper(Optimizer):
+ class FcmaesWrapper(WrapperBase):
      """Use fcmaes as pytorch optimizer. Particularly fcmaes has BITEOPT which appears to win in many benchmarks.
 
      Note that this performs full minimization on each step, so only perform one step with this.
@@ -42,7 +37,7 @@ class FcmaesWrapper(Optimizer):
              CMA-ES population size used for all CMA-ES runs.
              Not used for differential evolution.
              Ignored if parameter optimizer is defined. Defaults to 31.
-         capacity (int | None, optional): capacity of the evaluation store.. Defaults to 500.
+         capacity (int | None, optional): capacity of the evaluation store. Defaults to 500.
          stop_fitness (float | None, optional):
              Limit for fitness value. optimization runs terminate if this value is reached. Defaults to -np.inf.
          statistic_num (int | None, optional):
@@ -61,46 +56,30 @@ class FcmaesWrapper(Optimizer):
          popsize: int | None = 31,
          capacity: int | None = 500,
          stop_fitness: float | None = -np.inf,
-         statistic_num: int | None = 0
+         statistic_num: int | None = 0,
+         silence: bool = True,
      ):
-         super().__init__(params, lb=lb, ub=ub)
-         silence_fcmaes()
+         super().__init__(params, dict(lb=lb,ub=ub))
+         if silence:
+             silence_fcmaes()
          kwargs = locals().copy()
-         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__'], kwargs["silence"]
          self._kwargs = kwargs
          self._kwargs['workers'] = 1
 
-     def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
-         if self.raised: return np.inf
-         try:
-             params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-             return _ensure_float(closure(False))
-         except Exception as e:
-             # ha ha, I found a way to make exceptions work in fcmaes and scipy direct
-             self.e = e
-             self.raised = True
-             return np.inf
 
      @torch.no_grad
      def step(self, closure: Closure):
-         self.raised = False
-         self.e = None
 
-         params = self.get_params()
-
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         bounds = []
-         for p, l, u in zip(params, lb, ub):
-             bounds.extend([[l, u]] * p.numel())
+         params = TensorList(self._get_params())
+         bounds = self._get_bounds()
 
          res = fcmaes.retry.minimize(
-             partial(self._objective, params=params, closure=closure), # pyright:ignore[reportArgumentType]
+             partial(self._f, params=params, closure=closure), # pyright:ignore[reportArgumentType]
              bounds=bounds, # pyright:ignore[reportArgumentType]
              **self._kwargs
          )
 
-         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-
-         if self.e is not None: raise self.e from None
+         params.from_vec_(torch.as_tensor(res.x, device = params[0].device, dtype=params[0].dtype))
          return res.fun
 
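A usage sketch for the updated `FcmaesWrapper`: the keyword names `lb`, `ub` and `silence` come from the hunk above, the closure convention follows the `Closure = Callable[[bool], Any]` alias, and the import path simply mirrors the file location; treat the remaining details as assumptions.

```python
import torch
from torchzero.optim.wrappers.fcmaes import FcmaesWrapper  # path mirrors the file layout

params = [torch.nn.Parameter(torch.zeros(8))]

def closure(backward: bool = True):
    # fcmaes is derivative-free, so `backward` is ignored here
    return sum((p ** 2).sum() for p in params)

opt = FcmaesWrapper(params, lb=-5.0, ub=5.0, silence=False)  # keep fcmaes logging enabled
loss = opt.step(closure)  # one step = one full fcmaes.retry.minimize run
```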
torchzero/optim/wrappers/mads.py
@@ -6,24 +6,13 @@ import numpy as np
  import torch
  from mads.mads import orthomads
 
- from ...utils import Optimizer, TensorList
-
-
- def _ensure_float(x):
-     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-     if isinstance(x, np.ndarray): return x.item()
-     return float(x)
-
- def _ensure_numpy(x):
-     if isinstance(x, torch.Tensor): return x.detach().cpu()
-     if isinstance(x, np.ndarray): return x
-     return np.array(x)
-
+ from ...utils import TensorList
+ from .wrapper import WrapperBase
 
  Closure = Callable[[bool], Any]
 
 
- class MADS(Optimizer):
+ class MADS(WrapperBase):
      """Use mads.orthomads as pytorch optimizer.
 
      Note that this performs full minimization on each step,
@@ -53,37 +42,28 @@ class MADS(Optimizer):
          displog = False,
          savelog = False,
      ):
-         super().__init__(params, lb=lb, ub=ub)
+         super().__init__(params, dict(lb=lb, ub=ub))
 
          kwargs = locals().copy()
          del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
          self._kwargs = kwargs
 
-     def _objective(self, x: np.ndarray, params: TensorList, closure):
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
-         return _ensure_float(closure(False))
 
      @torch.no_grad
      def step(self, closure: Closure):
-         params = self.get_params()
-
-         x0 = params.to_vec().detach().cpu().numpy()
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
+         lb, ub = self._get_lb_ub()
 
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         bounds_lower = []
-         bounds_upper = []
-         for p, l, u in zip(params, lb, ub):
-             bounds_lower.extend([l] * p.numel())
-             bounds_upper.extend([u] * p.numel())
 
          f, x = orthomads(
             design_variables=x0,
-             bounds_upper=np.asarray(bounds_upper),
-             bounds_lower=np.asarray(bounds_lower),
-             objective_function=partial(self._objective, params = params, closure = closure),
+             bounds_upper=np.asarray(ub),
+             bounds_lower=np.asarray(lb),
+             objective_function=partial(self._f, params=params, closure=closure),
              **self._kwargs
          )
 
-         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         params.from_vec_(torch.as_tensor(x, device = params[0].device, dtype=params[0].dtype,))
          return f
 
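The deleted loop shows exactly what `self._get_lb_ub()` now centralizes: a per-parameter scalar bound has to be repeated once per element before `orthomads` can consume it. A self-contained sketch of that broadcast (the helper name is hypothetical, not the torchzero one):

```python
import numpy as np
import torch

def broadcast_bounds(params, lb, ub):
    # expand one (lb, ub) pair per parameter tensor into per-element arrays,
    # as the deleted bounds_lower/bounds_upper loop above did
    lower, upper = [], []
    for p, l, u in zip(params, lb, ub):
        lower.extend([l] * p.numel())
        upper.extend([u] * p.numel())
    return np.asarray(lower), np.asarray(upper)

lo, hi = broadcast_bounds([torch.zeros(2, 3), torch.zeros(4)], lb=[-1.0, -2.0], ub=[1.0, 2.0])
# lo.shape == hi.shape == (10,)
```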
torchzero/optim/wrappers/moors.py (new file)
@@ -0,0 +1,66 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import torch
+
+ from ...utils import TensorList
+ from .wrapper import WrapperBase
+
+ Closure = Callable[[bool], Any]
+
+ class MoorsWrapper(WrapperBase):
+     """Use moo-rs (pymoors) is PyTorch optimizer.
+
+     Note that this performs full minimization on each step,
+     so usually you would want to perform a single step.
+
+     To use this, define a function that accepts fitness function and number of variables and returns a pymoors algorithm:
+
+     ```python
+     alg_fn = lambda fitness_fn, num_vars: pymoors.Nsga2(
+         fitness_fn=fitness_fn,
+         num_vars=num_vars,
+         num_iterations=100,
+         sampler = pymoors.RandomSamplingFloat(min=-3, max=3),
+         crossover = pymoors.SinglePointBinaryCrossover(),
+         mutation = pymoors.GaussianMutation(gene_mutation_rate=1e-2, sigma=0.1),
+         population_size = 32,
+         num_offsprings = 32,
+     )
+
+     optimizer = MoorsWrapper(model.parameters(), alg_fn)
+     ```
+
+     All algorithms in pymoors have slightly different APIs, refer to their docs.
+
+     """
+     def __init__(
+         self,
+         params,
+         algorithm_fn: Callable[[Callable[[np.ndarray], np.ndarray], int], Any]
+     ):
+         super().__init__(params, {})
+         self._algorithm_fn = algorithm_fn
+
+     def _objective(self, x: np.ndarray, params, closure):
+         fs = []
+         for x_i in x:
+             f_i = self._fs(x_i, params=params, closure=closure)
+             fs.append(f_i)
+         return np.stack(fs, dtype=np.float64) # pymoors needs float64
+
+     @torch.no_grad
+     def step(self, closure: Closure):
+         params = TensorList(self._get_params())
+         objective = partial(self._objective, params=params, closure=closure)
+
+         algorithm = self._algorithm_fn(objective, params.global_numel())
+
+         algorithm.run()
+         pop = algorithm.population
+
+         params.from_vec_(torch.as_tensor(pop.best[0].genes, device = params[0].device, dtype=params[0].dtype,))
+         return pop.best[0].fitness
+
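`MoorsWrapper._objective` above receives a whole population at once: pymoors hands the fitness function a 2-D array with one row per individual and expects a float64 array of fitness values back, which is why the wrapper loops over rows and stacks the results. A standalone sketch of a fitness function with that shape contract (the two objectives are made up for illustration):

```python
import numpy as np

def fitness(population: np.ndarray) -> np.ndarray:
    # population: (n_individuals, num_vars); return (n_individuals, n_objectives) in float64
    sphere = (population ** 2).sum(axis=1)           # objective 1
    shifted = ((population - 1.0) ** 2).sum(axis=1)  # objective 2
    return np.stack([sphere, shifted], axis=1).astype(np.float64)

print(fitness(np.random.rand(32, 10)).shape)  # (32, 2)
```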
torchzero/optim/wrappers/nevergrad.py
@@ -6,7 +6,7 @@ import torch
 
  import nevergrad as ng
 
- from ...utils import Optimizer
+ from .wrapper import WrapperBase
 
 
  def _ensure_float(x) -> float:
@@ -14,7 +14,7 @@ def _ensure_float(x) -> float:
      if isinstance(x, np.ndarray): return float(x.item())
      return float(x)
 
- class NevergradWrapper(Optimizer):
+ class NevergradWrapper(WrapperBase):
      """Use nevergrad optimizer as pytorch optimizer.
      Note that it is recommended to specify `budget` to the number of iterations you expect to run,
      as some nevergrad optimizers will error without it.
@@ -72,7 +72,7 @@ class NevergradWrapper(Optimizer):
 
      @torch.no_grad
      def step(self, closure): # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-         params = self.get_params()
+         params = self._get_params()
          if self.opt is None:
              ng_params = []
              for group in self.param_groups:
@@ -95,7 +95,7 @@ class NevergradWrapper(Optimizer):
 
          x: ng.p.Tuple = self.opt.ask() # type:ignore
          for cur, new in zip(params, x):
-             cur.set_(torch.from_numpy(new.value).to(dtype=cur.dtype, device=cur.device, copy=False).reshape_as(cur)) # type:ignore
+             cur.set_(torch.as_tensor(new.value, dtype=cur.dtype, device=cur.device).reshape_as(cur)) # type:ignore
 
          loss = closure(False)
          self.opt.tell(x, _ensure_float(loss))
torchzero/optim/wrappers/nlopt.py
@@ -1,3 +1,4 @@
+ import warnings
  from typing import Literal, Any
  from collections.abc import Mapping, Callable
  from functools import partial
@@ -5,7 +6,8 @@ import numpy as np
  import torch
 
  import nlopt
- from ...utils import Optimizer, TensorList
+ from ...utils import TensorList
+ from .wrapper import WrapperBase
 
  _ALGOS_LITERAL = Literal[
      "GN_DIRECT", # = _nlopt.GN_DIRECT
@@ -69,14 +71,14 @@ def _ensure_tensor(x):
  inf = float('inf')
  Closure = Callable[[bool], Any]
 
- class NLOptWrapper(Optimizer):
+ class NLOptWrapper(WrapperBase):
      """Use nlopt as pytorch optimizer, with gradient supplied by pytorch autograd.
      Note that this performs full minimization on each step,
      so usually you would want to perform a single step, although performing multiple steps will refine the
      solution.
 
      Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
+         params (Iterable): iterable of parameters to optimize or dicts defining parameter groups.
          algorithm (int | _ALGOS_LITERAL): optimization algorithm from https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/
          maxeval (int | None):
              maximum allowed function evaluations, set to None to disable. But some stopping criterion
@@ -96,21 +98,30 @@ class NLOptWrapper(Optimizer):
          algorithm: int | _ALGOS_LITERAL,
          lb: float | None = None,
          ub: float | None = None,
-         maxeval: int | None = 10000, # None can stall on some algos and because they are threaded C you can't even interrupt them
+         maxeval: int | None = None, # None can stall on some algos and because they are threaded C you can't even interrupt them
          stopval: float | None = None,
          ftol_rel: float | None = None,
          ftol_abs: float | None = None,
          xtol_rel: float | None = None,
          xtol_abs: float | None = None,
          maxtime: float | None = None,
+         require_criterion: bool = True,
      ):
+         if require_criterion:
+             if all(i is None for i in (maxeval, stopval, ftol_abs, ftol_rel, xtol_abs, xtol_rel)):
+                 raise RuntimeError(
+                     "Specify at least one stopping criterion out of "
+                     "(maxeval, stopval, ftol_rel, ftol_abs, xtol_rel, xtol_abs, maxtime). "
+                     "Pass `require_criterion=False` to suppress this error."
+                 )
+
          defaults = dict(lb=lb, ub=ub)
          super().__init__(params, defaults)
 
          self.opt: nlopt.opt | None = None
+         self.algorithm_name: str | int = algorithm
          if isinstance(algorithm, str): algorithm = getattr(nlopt, algorithm.upper())
          self.algorithm: int = algorithm # type:ignore
-         self.algorithm_name: str | None = None
 
          self.maxeval = maxeval; self.stopval = stopval
          self.ftol_rel = ftol_rel; self.ftol_abs = ftol_abs
@@ -119,7 +130,7 @@ class NLOptWrapper(Optimizer):
 
          self._last_loss = None
 
-     def _f(self, x: np.ndarray, grad: np.ndarray, closure, params: TensorList):
+     def _objective(self, x: np.ndarray, grad: np.ndarray, closure, params: TensorList):
          if self.raised:
              if self.opt is not None: self.opt.force_stop()
              return np.inf
@@ -132,7 +143,7 @@ class NLOptWrapper(Optimizer):
          if grad.size > 0:
              with torch.enable_grad(): loss = closure()
              self._last_loss = _ensure_float(loss)
-             grad[:] = params.ensure_grad_().grad.to_vec().reshape(grad.shape).detach().cpu().numpy()
+             grad[:] = params.grad.fill_none_(reference=params).to_vec().reshape(grad.shape).numpy(force=True)
              return self._last_loss
 
          self._last_loss = _ensure_float(closure(False))
@@ -147,25 +158,20 @@ class NLOptWrapper(Optimizer):
      def step(self, closure: Closure): # pylint: disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
          self.e = None
          self.raised = False
-         params = self.get_params()
-
-         # make bounds
-         lb, ub = self.group_vals('lb', 'ub', cls=list)
-         lower = []
-         upper = []
-         for p, l, u in zip(params, lb, ub):
-             if l is None: l = -inf
-             if u is None: u = inf
-             lower.extend([l] * p.numel())
-             upper.extend([u] * p.numel())
+         params = TensorList(self._get_params())
+         x0 = params.to_vec().numpy(force=True)
 
-         x0 = params.to_vec().detach().cpu().numpy().astype(np.float64)
+         plb, pub = self._get_per_parameter_lb_ub()
+         if all(i is None for i in plb) and all(i is None for i in pub):
+             lb = ub = None
+         else:
+             lb, ub = self._get_lb_ub(ld = {None: -np.inf}, ud = {None: np.inf})
 
          self.opt = nlopt.opt(self.algorithm, x0.size)
          self.opt.set_exceptions_enabled(False) # required
-         self.opt.set_min_objective(partial(self._f, closure = closure, params = params))
-         self.opt.set_lower_bounds(lower)
-         self.opt.set_upper_bounds(upper)
+         self.opt.set_min_objective(partial(self._objective, closure = closure, params = params))
+         if lb is not None: self.opt.set_lower_bounds(np.asarray(lb, dtype=x0.dtype))
+         if ub is not None: self.opt.set_upper_bounds(np.asarray(ub, dtype=x0.dtype))
 
          if self.maxeval is not None: self.opt.set_maxeval(self.maxeval)
          if self.stopval is not None: self.opt.set_stopval(self.stopval)
@@ -179,12 +185,12 @@ class NLOptWrapper(Optimizer):
          x = None
          try:
              x = self.opt.optimize(x0)
-         except SystemError:
-             pass
+         # except SystemError as s:
+         #     warnings.warn(f"{self.algorithm_name} raised {s}")
          except Exception as e:
              raise e from None
 
-         if x is not None: params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         if x is not None: params.from_vec_(torch.as_tensor(x, device = params[0].device, dtype=params[0].dtype))
          if self.e is not None: raise self.e from None
 
          if self._last_loss is None or x is None: return closure(False)
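
The new `require_criterion` guard means a bare `NLOptWrapper(params, algorithm=...)` now raises instead of potentially running forever in un-interruptible C code, since the old default `maxeval=10000` is gone. A usage sketch: `GN_DIRECT`, the keyword names and the closure convention are taken from the hunks above, the import path mirrors the file layout, and everything else is an assumption.

```python
import torch
from torchzero.optim.wrappers.nlopt import NLOptWrapper  # path mirrors the file layout

params = [torch.nn.Parameter(torch.zeros(4))]

def closure(backward: bool = True):
    loss = sum((p ** 2).sum() for p in params)
    if backward:
        for p in params:
            p.grad = None
        loss.backward()  # gradient-based nlopt algorithms read .grad through the closure
    return loss

# at least one stopping criterion is now required by default
opt = NLOptWrapper(params, algorithm="GN_DIRECT", lb=-2.0, ub=2.0, maxeval=2000)
loss = opt.step(closure)

# opting out of the check restores the old "run until nlopt decides to stop" behaviour:
# NLOptWrapper(params, algorithm="GN_DIRECT", lb=-2.0, ub=2.0, require_criterion=False)
```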
torchzero/optim/wrappers/optuna.py
@@ -1,23 +1,16 @@
- import typing
- from collections import abc
-
- import numpy as np
+ import optuna
  import torch
 
- import optuna
+ from ...utils import TensorList, tofloat, totensor
+ from .wrapper import WrapperBase
 
- from ...utils import Optimizer, totensor, tofloat
 
  def silence_optuna():
      optuna.logging.set_verbosity(optuna.logging.WARNING)
 
- def _ensure_float(x) -> float:
-     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-     if isinstance(x, np.ndarray): return float(x.item())
-     return float(x)
 
 
- class OptunaSampler(Optimizer):
+ class OptunaSampler(WrapperBase):
      """Optimize your next SOTA model using hyperparameter optimization.
 
      Note - optuna is surprisingly scalable to large number of parameters (up to 10,000), despite literally requiring a for-loop because it only supports scalars. Default TPESampler is good for BBO. Maybe not for NNs...
@@ -38,7 +31,7 @@ class OptunaSampler(Optimizer):
          silence: bool = True,
      ):
          if silence: silence_optuna()
-         super().__init__(params, lb=lb, ub=ub)
+         super().__init__(params, dict(lb=lb, ub=ub))
 
          if isinstance(sampler, type): sampler = sampler()
          self.sampler = sampler
@@ -47,7 +40,7 @@ class OptunaSampler(Optimizer):
      @torch.no_grad
      def step(self, closure):
 
-         params = self.get_params()
+         params = TensorList(self._get_params())
          if self.study is None:
              self.study = optuna.create_study(sampler=self.sampler)
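
The docstring's remark that optuna works despite "literally requiring a for-loop because it only supports scalars" refers to the ask/tell pattern, where every parameter element is suggested as an individual scalar. A standalone sketch of that pattern using plain optuna (this shows the general optuna API, not necessarily `OptunaSampler`'s exact internals):

```python
import optuna
import torch

optuna.logging.set_verbosity(optuna.logging.WARNING)

params = [torch.zeros(5)]
study = optuna.create_study()  # default TPESampler

for _ in range(20):
    trial = study.ask()
    # one scalar suggestion per parameter element
    n_total = sum(p.numel() for p in params)
    flat = [trial.suggest_float(f"x{i}", -3.0, 3.0) for i in range(n_total)]
    vec = torch.tensor(flat)
    offset = 0
    for p in params:
        n = p.numel()
        p.copy_(vec[offset:offset + n].view_as(p))
        offset += n
    loss = float(sum((p ** 2).sum() for p in params))
    study.tell(trial, loss)
```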