torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/optim/wrappers/nlopt.py
@@ -69,7 +69,7 @@ def _ensure_tensor(x):
 inf = float('inf')
 Closure = Callable[[bool], Any]

-class NLOptOptimizer(Optimizer):
+class NLOptWrapper(Optimizer):
     """Use nlopt as pytorch optimizer, with gradient supplied by pytorch autograd.
     Note that this performs full minimization on each step,
     so usually you would want to perform a single step, although performing multiple steps will refine the
@@ -96,9 +96,9 @@ class NLOptOptimizer(Optimizer):
         self,
         params,
         algorithm: int | _ALGOS_LITERAL,
-        maxeval: int | None,
         lb: float | None = None,
         ub: float | None = None,
+        maxeval: int | None = 10000, # None can stall on some algos and because they are threaded C you can't even interrupt them
         stopval: float | None = None,
         ftol_rel: float | None = None,
         ftol_abs: float | None = None,
@@ -122,22 +122,33 @@ class NLOptOptimizer(Optimizer):
         self._last_loss = None

     def _f(self, x: np.ndarray, grad: np.ndarray, closure, params: TensorList):
-        t = _ensure_tensor(x)
-        if t is None:
+        if self.raised:
             if self.opt is not None: self.opt.force_stop()
-            return None
-        params.from_vec_(t.to(params[0], copy=False))
-        if grad.size > 0:
-            with torch.enable_grad(): loss = closure()
-            self._last_loss = _ensure_float(loss)
-            grad[:] = params.ensure_grad_().grad.to_vec().reshape(grad.shape).detach().cpu().numpy()
+            return np.inf
+        try:
+            t = _ensure_tensor(x)
+            if t is None:
+                if self.opt is not None: self.opt.force_stop()
+                return None
+            params.from_vec_(t.to(params[0], copy=False))
+            if grad.size > 0:
+                with torch.enable_grad(): loss = closure()
+                self._last_loss = _ensure_float(loss)
+                grad[:] = params.ensure_grad_().grad.to_vec().reshape(grad.shape).detach().cpu().numpy()
+                return self._last_loss
+
+            self._last_loss = _ensure_float(closure(False))
             return self._last_loss
-
-        self._last_loss = _ensure_float(closure(False))
-        return self._last_loss
+        except Exception as e:
+            self.e = e
+            self.raised = True
+            if self.opt is not None: self.opt.force_stop()
+            return np.inf

     @torch.no_grad
     def step(self, closure: Closure): # pylint: disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+        self.e = None
+        self.raised = False
         params = self.get_params()

         # make bounds
@@ -175,6 +186,9 @@ class NLOptOptimizer(Optimizer):
         except Exception as e:
             raise e from None

+        if x is not None: params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        if self.e is not None: raise self.e from None
+
         if self._last_loss is None or x is None: return closure(False)
-        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
         return self._last_loss
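Usage note (illustrative, not part of the diff): the renamed NLOptWrapper is driven with a torchzero-style closure that takes a backward flag, matching the Closure = Callable[[bool], Any] alias and the closure()/closure(False) calls above. A minimal sketch, assuming the algorithm string is one of the names accepted by _ALGOS_LITERAL (e.g. "LD_LBFGS"):

import torch
import torch.nn as nn
from torchzero.optim.wrappers.nlopt import NLOptWrapper  # module path from the file list above

model = nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)
opt = NLOptWrapper(model.parameters(), algorithm="LD_LBFGS", maxeval=1000)

def closure(backward: bool = True):
    loss = nn.functional.mse_loss(model(X), y)
    if backward:
        model.zero_grad()
        loss.backward()  # the wrapper reads .grad when nlopt requests a gradient
    return loss

loss = opt.step(closure)  # one step() runs a full nlopt minimization
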
torchzero/optim/wrappers/optuna.py
@@ -0,0 +1,70 @@
+import typing
+from collections import abc
+
+import numpy as np
+import torch
+
+import optuna
+
+from ...utils import Optimizer
+
+def silence_optuna():
+    optuna.logging.set_verbosity(optuna.logging.WARNING)
+
+def _ensure_float(x) -> float:
+    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+    if isinstance(x, np.ndarray): return float(x.item())
+    return float(x)
+
+
+class OptunaSampler(Optimizer):
+    """Optimize your next SOTA model using hyperparameter optimization.
+
+    Note - optuna is surprisingly scalable to large number of parameters (up to 10,000), despite literally requiring a for-loop because it only supports scalars. Default TPESampler is good for BBO. Maybe not for NNs...
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lb (float): lower bounds.
+        ub (float): upper bounds.
+        sampler (optuna.samplers.BaseSampler | type[optuna.samplers.BaseSampler] | None, optional): sampler. Defaults to None.
+        silence (bool, optional): makes optuna not write a lot of very useful information to console. Defaults to True.
+    """
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        sampler: "optuna.samplers.BaseSampler | type[optuna.samplers.BaseSampler] | None" = None,
+        silence: bool = True,
+    ):
+        if silence: silence_optuna()
+        super().__init__(params, lb=lb, ub=ub)
+
+        if isinstance(sampler, type): sampler = sampler()
+        self.sampler = sampler
+        self.study = None
+
+    @torch.no_grad
+    def step(self, closure):
+
+        params = self.get_params()
+        if self.study is None:
+            self.study = optuna.create_study(sampler=self.sampler)
+
+        # some optuna samplers use torch
+        with torch.enable_grad():
+            trial = self.study.ask()
+
+        suggested = []
+        for gi,g in enumerate(self.param_groups):
+            for pi,p in enumerate(g['params']):
+                lb, ub = g['lb'], g['ub']
+                suggested.extend(trial.suggest_float(f'g{gi}_p{pi}_w{i}', lb, ub) for i in range(p.numel()))
+
+        vec = torch.as_tensor(suggested).to(params[0])
+        params.from_vec_(vec)
+
+        loss = closure()
+        with torch.enable_grad(): self.study.tell(trial, loss)
+
+        return loss
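Usage note (illustrative, not part of the diff): a minimal sketch of driving the new OptunaSampler on a small gradient-free problem, using the lb/ub bounds and the no-gradient closure() call shown above; the quadratic objective is purely illustrative.

import torch
from torchzero.optim.wrappers.optuna import OptunaSampler  # module path from the file list above

x = torch.nn.Parameter(torch.zeros(8))
opt = OptunaSampler([x], lb=-5.0, ub=5.0)

def closure(backward: bool = True):
    # gradient-free: step() only asks for the objective value
    return float((x - 1.23).pow(2).sum())

for _ in range(50):
    loss = opt.step(closure)
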
torchzero/optim/wrappers/scipy.py
@@ -11,9 +11,9 @@ from ...utils import Optimizer, TensorList
 from ...utils.derivatives import jacobian_and_hessian_mat_wrt, jacobian_wrt
 from ...modules.second_order.newton import tikhonov_

-def _ensure_float(x):
+def _ensure_float(x) -> float:
     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    if isinstance(x, np.ndarray): return x.item()
+    if isinstance(x, np.ndarray): return float(x.item())
     return float(x)

 def _ensure_numpy(x):
@@ -139,9 +139,11 @@ class ScipyMinimize(Optimizer):

         # make bounds
         lb, ub = self.group_vals('lb', 'ub', cls=list)
-        bounds = []
-        for p, l, u in zip(params, lb, ub):
-            bounds.extend([(l, u)] * p.numel())
+        bounds = None
+        if any(b is not None for b in lb) or any(b is not None for b in ub):
+            bounds = []
+            for p, l, u in zip(params, lb, ub):
+                bounds.extend([(l, u)] * p.numel())

         if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
             x0 = x0.astype(np.float64) # those methods error without this
@@ -265,7 +267,8 @@ class ScipyDE(Optimizer):
     def __init__(
         self,
         params,
-        bounds: tuple[float,float],
+        lb: float,
+        ub: float,
         strategy: Literal['best1bin', 'best1exp', 'rand1bin', 'rand1exp', 'rand2bin', 'rand2exp',
                           'randtobest1bin', 'randtobest1exp', 'currenttobest1bin', 'currenttobest1exp',
                           'best2exp', 'best2bin'] = 'best1bin',
@@ -287,12 +290,11 @@ class ScipyDE(Optimizer):
         integrality = None,

     ):
-        super().__init__(params, {})
+        super().__init__(params, lb=lb, ub=ub)

         kwargs = locals().copy()
-        del kwargs['self'], kwargs['params'], kwargs['bounds'], kwargs['__class__']
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
         self._kwargs = kwargs
-        self._lb, self._ub = bounds

     def _objective(self, x: np.ndarray, params: TensorList, closure):
         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
@@ -303,7 +305,11 @@ class ScipyDE(Optimizer):
         params = self.get_params()

         x0 = params.to_vec().detach().cpu().numpy()
-        bounds = [(self._lb, self._ub)] * len(x0)
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())

         res = scipy.optimize.differential_evolution(
             partial(self._objective, params = params, closure = closure),
@@ -321,7 +327,8 @@ class ScipyDualAnnealing(Optimizer):
     def __init__(
         self,
         params,
-        bounds: tuple[float, float],
+        lb: float,
+        ub: float,
         maxiter=1000,
         minimizer_kwargs=None,
         initial_temp=5230.0,
@@ -332,23 +339,25 @@ class ScipyDualAnnealing(Optimizer):
         rng=None,
         no_local_search=False,
     ):
-        super().__init__(params, {})
+        super().__init__(params, lb=lb, ub=ub)

         kwargs = locals().copy()
-        del kwargs['self'], kwargs['params'], kwargs['bounds'], kwargs['__class__']
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
         self._kwargs = kwargs
-        self._lb, self._ub = bounds

     def _objective(self, x: np.ndarray, params: TensorList, closure):
         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
         return _ensure_float(closure(False))

     @torch.no_grad
-    def step(self, closure: Closure):# pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+    def step(self, closure: Closure):
         params = self.get_params()

         x0 = params.to_vec().detach().cpu().numpy()
-        bounds = [(self._lb, self._ub)] * len(x0)
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())

         res = scipy.optimize.dual_annealing(
             partial(self._objective, params = params, closure = closure),
@@ -360,3 +369,145 @@ class ScipyDualAnnealing(Optimizer):
         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
         return res.fun

+
+
+class ScipySHGO(Optimizer):
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        constraints = None,
+        n: int = 100,
+        iters: int = 1,
+        callback = None,
+        minimizer_kwargs = None,
+        options = None,
+        sampling_method: str = 'simplicial',
+    ):
+        super().__init__(params, lb=lb, ub=ub)
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
+
+        res = scipy.optimize.shgo(
+            partial(self._objective, params = params, closure = closure),
+            bounds=bounds,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return res.fun
+
+
+class ScipyDIRECT(Optimizer):
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        maxfun: int | None = 1000,
+        maxiter: int = 1000,
+        eps: float = 0.0001,
+        locally_biased: bool = True,
+        f_min: float = -np.inf,
+        f_min_rtol: float = 0.0001,
+        vol_tol: float = 1e-16,
+        len_tol: float = 0.000001,
+        callback = None,
+    ):
+        super().__init__(params, lb=lb, ub=ub)
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
+        if self.raised: return np.inf
+        try:
+            params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+            return _ensure_float(closure(False))
+        except Exception as e:
+            # he he he ha, I found a way to make exceptions work in fcmaes and scipy direct
+            self.e = e
+            self.raised = True
+            return np.inf
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        self.raised = False
+        self.e = None
+
+        params = self.get_params()
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
+
+        res = scipy.optimize.direct(
+            partial(self._objective, params=params, closure=closure),
+            bounds=bounds,
+            **self._kwargs
+        )
+
+        params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+        if self.e is not None: raise self.e from None
+        return res.fun
+
+
+
+
+class ScipyBrute(Optimizer):
+    def __init__(
+        self,
+        params,
+        lb: float,
+        ub: float,
+        Ns: int = 20,
+        full_output: int = 0,
+        finish = scipy.optimize.fmin,
+        disp: bool = False,
+        workers: int = 1
+    ):
+        super().__init__(params, lb=lb, ub=ub)
+
+        kwargs = locals().copy()
+        del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
+        self._kwargs = kwargs
+
+    def _objective(self, x: np.ndarray, params: TensorList, closure):
+        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return _ensure_float(closure(False))
+
+    @torch.no_grad
+    def step(self, closure: Closure):
+        params = self.get_params()
+
+        lb, ub = self.group_vals('lb', 'ub', cls=list)
+        bounds = []
+        for p, l, u in zip(params, lb, ub):
+            bounds.extend([(l, u)] * p.numel())
+
+        x0 = scipy.optimize.brute(
+            partial(self._objective, params = params, closure = closure),
+            ranges=bounds,
+            **self._kwargs
+        )
+        params.from_vec_(torch.from_numpy(x0).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+        return None
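Usage note (illustrative, not part of the diff): the scipy global-optimization wrappers now take scalar lb/ub per parameter group instead of a single bounds tuple, and expand them to one (lb, ub) pair per scalar element, as the bounds-building loops above show. A minimal sketch with ScipyDE on an illustrative objective:

import torch
from torchzero.optim.wrappers.scipy import ScipyDE

w = torch.nn.Parameter(torch.zeros(6))
opt = ScipyDE([w], lb=-2.0, ub=2.0)

def closure(backward: bool = True):
    # gradient-free: the wrapper calls closure(False) and converts the value to float
    return (w.sin() + w.pow(2)).sum()

loss = opt.step(closure)  # one step() runs a full differential_evolution solve
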
torchzero/utils/__init__.py
@@ -9,11 +9,7 @@ from .optimizer import (
     get_group_vals,
     get_params,
     get_state_vals,
-    grad_at_params,
-    grad_vec_at_params,
-    loss_at_params,
-    loss_grad_at_params,
-    loss_grad_vec_at_params,
+    unpack_states,
 )
 from .params import (
     Params,
@@ -22,6 +18,6 @@ from .params import (
     _copy_param_groups,
     _make_param_groups,
 )
-from .python_tools import flatten, generic_eq, reduce_dim
-from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like
+from .python_tools import flatten, generic_eq, generic_ne, reduce_dim, unpack_dicts
+from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like, generic_finfo_eps
 from .torch_tools import tofloat, tolist, tonumpy, totensor, vec_to_tensors, vec_to_tensors_, set_storage_
torchzero/utils/derivatives.py
@@ -2,6 +2,7 @@ from collections.abc import Iterable, Sequence

 import torch
 import torch.autograd.forward_ad as fwAD
+from typing import Literal

 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors

@@ -157,7 +158,7 @@ def hessian_mat(
     method="func",
     vectorize=False,
     outer_jacobian_strategy="reverse-mode",
-):
+) -> torch.Tensor:
     """
     returns hessian matrix for parameters (as if they were flattened and concatenated into a vector).

@@ -189,7 +190,7 @@ def hessian_mat(
         return loss

     if method == 'func':
-        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph))
+        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph)) # pyright:ignore[reportReturnType]

     if method == 'autograd.functional':
         return torch.autograd.functional.hessian(
@@ -198,7 +199,7 @@ def hessian_mat(
            create_graph=create_graph,
            vectorize=vectorize,
            outer_jacobian_strategy=outer_jacobian_strategy,
-        )
+        ) # pyright:ignore[reportReturnType]
     raise ValueError(method)

 def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
@@ -510,4 +511,4 @@ def hvp_fd_forward(
     torch._foreach_div_(hvp_, h)

     if normalize: torch._foreach_mul_(hvp_, vec_norm)
-    return loss, hvp_
+    return loss, hvp_
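Background note (illustrative, not part of the diff): the method="func" branch annotated above is built on torch.func.hessian applied to a closure over the flattened parameter vector. A standalone sketch of that pattern, with a toy objective:

import torch

def f(xvec: torch.Tensor) -> torch.Tensor:
    # toy scalar objective over a flat parameter vector
    return (xvec ** 2).sum() + xvec.sin().prod()

x = torch.randn(5)
H = torch.func.hessian(f)(x)  # (5, 5) Hessian, the shape hessian_mat returns for 5 parameters
print(H.shape)
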
torchzero/utils/linalg/__init__.py
@@ -2,4 +2,4 @@ from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix
 from .orthogonalize import gram_schmidt
 from .qr import qr_householder
 from .svd import randomized_svd
-from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
+from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve, steihaug_toint_cg