torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0

torchzero/modules/weight_decay/weight_decay.py

@@ -4,7 +4,7 @@ from typing import Literal
  import torch

  from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+ from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states, Metrics


  @torch.no_grad
@@ -14,7 +14,7 @@ def weight_decay_(
  weight_decay: float | NumberList,
  ord: int = 2
  ):
- """returns `grad_`."""
+ """modifies in-place and returns ``grad_``."""
  if ord == 1: return grad_.add_(params.sign().mul_(weight_decay))
  if ord == 2: return grad_.add_(params.mul(weight_decay))
  if ord - 1 % 2 != 0: return grad_.add_(params.pow(ord-1).mul_(weight_decay))
@@ -29,39 +29,38 @@ class WeightDecay(Transform):
  ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
  target (Target, optional): what to set on var. Defaults to 'update'.

- Examples:
- Adam with non-decoupled weight decay
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.WeightDecay(1e-3),
- tz.m.Adam(),
- tz.m.LR(1e-3)
- )
-
- Adam with decoupled weight decay that still scales with learning rate
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.Adam(),
- tz.m.WeightDecay(1e-3),
- tz.m.LR(1e-3)
- )
-
- Adam with fully decoupled weight decay that doesn't scale with learning rate
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.Adam(),
- tz.m.LR(1e-3),
- tz.m.WeightDecay(1e-6)
- )
+ ### Examples:
+
+ Adam with non-decoupled weight decay
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.WeightDecay(1e-3),
+ tz.m.Adam(),
+ tz.m.LR(1e-3)
+ )
+ ```
+
+ Adam with decoupled weight decay that still scales with learning rate
+ ```python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.WeightDecay(1e-3),
+ tz.m.LR(1e-3)
+ )
+ ```
+
+ Adam with fully decoupled weight decay that doesn't scale with learning rate
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.LR(1e-3),
+ tz.m.WeightDecay(1e-6)
+ )
+ ```

  """
  def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
@@ -77,7 +76,7 @@ class WeightDecay(Transform):
  return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

  class RelativeWeightDecay(Transform):
- """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of :code:`norm_input` argument.
+ """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.

  Args:
  weight_decay (float): relative weight decay scale.
@@ -85,40 +84,42 @@ class RelativeWeightDecay(Transform):
  norm_input (str, optional):
  determines what should weight decay be relative to. "update", "grad" or "params".
  Defaults to "update".
+ metric (Ords, optional):
+ metric (norm, etc) that weight decay should be relative to.
+ defaults to 'mad' (mean absolute deviation).
  target (Target, optional): what to set on var. Defaults to 'update'.

- Examples:
- Adam with non-decoupled relative weight decay
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.RelativeWeightDecay(1e-3),
- tz.m.Adam(),
- tz.m.LR(1e-3)
- )
-
- Adam with decoupled relative weight decay
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.Adam(),
- tz.m.RelativeWeightDecay(1e-3),
- tz.m.LR(1e-3)
- )
-
+ ### Examples:
+
+ Adam with non-decoupled relative weight decay
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.RelativeWeightDecay(1e-1),
+ tz.m.Adam(),
+ tz.m.LR(1e-3)
+ )
+ ```
+
+ Adam with decoupled relative weight decay
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.RelativeWeightDecay(1e-1),
+ tz.m.LR(1e-3)
+ )
+ ```
  """
  def __init__(
  self,
  weight_decay: float = 0.1,
  ord: int = 2,
  norm_input: Literal["update", "grad", "params"] = "update",
+ metric: Metrics = 'mad',
  target: Target = "update",
  ):
- defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input)
+ defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
  super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)

  @torch.no_grad
@@ -127,6 +128,7 @@ class RelativeWeightDecay(Transform):

  ord = settings[0]['ord']
  norm_input = settings[0]['norm_input']
+ metric = settings[0]['metric']

  if norm_input == 'update': src = TensorList(tensors)
  elif norm_input == 'grad':
@@ -137,9 +139,8 @@
  else:
  raise ValueError(norm_input)

- mean_abs = src.abs().global_mean()
-
- return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * mean_abs, ord)
+ norm = src.global_metric(metric)
+ return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)


  @torch.no_grad
@@ -162,7 +163,7 @@ class DirectWeightDecay(Module):
  @torch.no_grad
  def step(self, var):
  weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
- ord = self.settings[var.params[0]]['ord']
+ ord = self.defaults['ord']

  decay_weights_(var.params, weight_decay, ord)
  return var
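
The reworked `RelativeWeightDecay` gains a `metric` argument (default `'mad'`, mean absolute deviation) and now calls `src.global_metric(metric)` instead of `src.abs().global_mean()`. A minimal sketch of passing it explicitly, reusing the composition pattern from the docstring above; other metric names are not listed in these hunks, so only the default is shown:

```python
# Sketch (assumes the usual `import torchzero as tz` alias used in the docstrings above).
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.RelativeWeightDecay(1e-1, metric="mad"),  # 'mad' is the documented default
    tz.m.LR(1e-3),
)
```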

torchzero/modules/zeroth_order/__init__.py

@@ -0,0 +1 @@
+ from .cd import CD

torchzero/modules/zeroth_order/cd.py

@@ -0,0 +1,122 @@
+ import math
+ import random
+ import warnings
+ from functools import partial
+ from typing import Literal
+
+ import numpy as np
+ import torch
+
+ from ...core import Module
+ from ...utils import NumberList, TensorList
+
+ class CD(Module):
+ """Coordinate descent. Proposes a descent direction along a single coordinate.
+ A line search such as ``tz.m.ScipyMinimizeScalar(maxiter=8)`` or a fixed step size can be used after this.
+
+ Args:
+ h (float, optional): finite difference step size. Defaults to 1e-3.
+ grad (bool, optional):
+ if True, scales direction by gradient estimate. If False, the scale is fixed to 1. Defaults to True.
+ adaptive (bool, optional):
+ whether to adapt finite difference step size, this requires an additional buffer. Defaults to True.
+ index (str, optional):
+ index selection strategy.
+ - "cyclic" - repeatedly cycles through each coordinate, e.g. ``1,2,3,1,2,3,...``.
+ - "cyclic2" - cycles forward and then backward, e.g ``1,2,3,3,2,1,1,2,3,...`` (default).
+ - "random" - picks coordinate randomly.
+ threepoint (bool, optional):
+ whether to use three points (three function evaluatins) to determine descent direction.
+ if False, uses two points, but then ``adaptive`` can't be used. Defaults to True.
+ """
+ def __init__(self, h:float=1e-3, grad:bool=True, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
+ defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
+ super().__init__(defaults)
+
+ @torch.no_grad
+ def step(self, var):
+ closure = var.closure
+ if closure is None:
+ raise RuntimeError("CD requires closure")
+
+ params = TensorList(var.params)
+ ndim = params.global_numel()
+
+ grad_step_size = self.defaults['grad']
+ adaptive = self.defaults['adaptive']
+ index_strategy = self.defaults['index']
+ h = self.defaults['h']
+ threepoint = self.defaults['threepoint']
+
+ # ------------------------------ determine index ----------------------------- #
+ if index_strategy == 'cyclic':
+ idx = self.global_state.get('idx', 0) % ndim
+ self.global_state['idx'] = idx + 1
+
+ elif index_strategy == 'cyclic2':
+ idx = self.global_state.get('idx', 0)
+ self.global_state['idx'] = idx + 1
+ if idx >= ndim * 2:
+ idx = self.global_state['idx'] = 0
+ if idx >= ndim:
+ idx = (2*ndim - idx) - 1
+
+ elif index_strategy == 'random':
+ if 'generator' not in self.global_state:
+ self.global_state['generator'] = random.Random(0)
+ generator = self.global_state['generator']
+ idx = generator.randrange(0, ndim)
+
+ else:
+ raise ValueError(index_strategy)
+
+ # -------------------------- find descent direction -------------------------- #
+ h_vec = None
+ if adaptive:
+ if threepoint:
+ h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, h), cls=TensorList)
+ h = float(h_vec.flat_get(idx))
+ else:
+ warnings.warn("CD adaptive=True only works with threepoint=True")
+
+ f_0 = var.get_loss(False)
+ params.flat_set_lambda_(idx, lambda x: x + h)
+ f_p = closure(False)
+
+ # -------------------------------- threepoint -------------------------------- #
+ if threepoint:
+ params.flat_set_lambda_(idx, lambda x: x - 2*h)
+ f_n = closure(False)
+ params.flat_set_lambda_(idx, lambda x: x + h)
+
+ if adaptive:
+ assert h_vec is not None
+ if f_0 <= f_p and f_0 <= f_n:
+ h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
+ else:
+ if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
+ h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))
+
+ if grad_step_size:
+ alpha = (f_p - f_n) / (2*h)
+
+ else:
+ if f_0 < f_p and f_0 < f_n: alpha = 0
+ elif f_p < f_n: alpha = -1
+ else: alpha = 1
+
+ # --------------------------------- twopoint --------------------------------- #
+ else:
+ params.flat_set_lambda_(idx, lambda x: x - h)
+ if grad_step_size:
+ alpha = (f_p - f_0) / h
+ else:
+ if f_p < f_0: alpha = -1
+ else: alpha = 1
+
+ # ----------------------------- create the update ---------------------------- #
+ update = params.zeros_like()
+ update.flat_set_(idx, alpha)
+ var.update = update
+ return var
+
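
The `CD` docstring above suggests following it with a line search such as `tz.m.ScipyMinimizeScalar(maxiter=8)` or a fixed step size. A minimal usage sketch, assuming `CD` is re-exported as `tz.m.CD` (the new `zeroth_order/__init__.py` above imports it) and composed with `tz.Modular` as in the earlier examples:

```python
import torchzero as tz  # alias assumed, matching the docstring examples

# Hypothetical composition (not part of the diff): CD proposes a single-coordinate
# direction, then a scalar line search chooses how far to move along it.
opt = tz.Modular(
    model.parameters(),
    tz.m.CD(h=1e-3, index="cyclic2"),
    tz.m.ScipyMinimizeScalar(maxiter=8),
)
```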

torchzero/optim/root.py

@@ -0,0 +1,65 @@
+ """WIP, untested"""
+ from collections.abc import Callable
+
+ from abc import abstractmethod
+ import torch
+ from ..modules.higher_order.multipoint import sixth_order_im1, sixth_order_p6, _solve
+
+ def make_evaluate(f: Callable[[torch.Tensor], torch.Tensor]):
+ def evaluate(x, order) -> tuple[torch.Tensor, ...]:
+ """order=0 - returns (f,), order=1 - returns (f, J), order=2 - returns (f, J, H), etc."""
+ n = x.numel()
+
+ if order == 0:
+ f_x = f(x)
+ return (f_x, )
+
+ x.requires_grad_()
+ with torch.enable_grad():
+ f_x = f(x)
+ I = torch.eye(n, device=x.device, dtype=x.dtype),
+ g_x = torch.autograd.grad(f_x, x, I, create_graph=order!=1, is_grads_batched=True)[0]
+ ret = [f_x, g_x]
+ T = g_x
+
+ # get all derivative up to order
+ for o in range(2, order + 1):
+ is_last = o == order
+ I = torch.eye(T.numel(), device=x.device, dtype=x.dtype),
+ T = torch.autograd.grad(T.ravel(), x, I, create_graph=not is_last, is_grads_batched=True)[0]
+ ret.append(T.view(n, n, *T.shape[1:]))
+
+ return tuple(ret)
+
+ return evaluate
+
+ class RootBase:
+ @abstractmethod
+ def one_iteration(
+ self,
+ x: torch.Tensor,
+ evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
+ ) -> torch.Tensor:
+ """"""
+
+
+ # ---------------------------------- methods --------------------------------- #
+ def newton(x:torch.Tensor, f_j, lstsq:bool=False):
+ f_x, G_x = f_j(x)
+ return x - _solve(G_x, f_x, lstsq=lstsq)
+
+ class Newton(RootBase):
+ def __init__(self, lstsq: bool=False): self.lstsq = lstsq
+ def one_iteration(self, x, evaluate): return newton(x, evaluate, self.lstsq)
+
+
+ class SixthOrderP6(RootBase):
+ """sixth-order iterative method
+
+ Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
+ """
+ def __init__(self, lstsq: bool=False): self.lstsq = lstsq
+ def one_iteration(self, x, evaluate):
+ def f(x): return evaluate(x, 0)[0]
+ def f_j(x): return evaluate(x, 1)
+ return sixth_order_p6(x, f, f_j, self.lstsq)
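
`root.py` is explicitly marked "WIP, untested" (and its `..modules.higher_order.multipoint` import may not resolve in this build, since the new `multipoint.py` in this diff sits under `second_order`), so the following only sketches the apparent intent. Based on the `evaluate` docstring above ("order=1 - returns (f, J)"), `make_evaluate` wraps a residual function and returns its value and Jacobian; the import path is inferred from the file list and the residual below is purely illustrative:

```python
import torch
from torchzero.optim.root import make_evaluate  # path assumed from the file list

# Illustrative 2x2 nonlinear system (not from the package).
def residual(x: torch.Tensor) -> torch.Tensor:
    return torch.stack([x[0] ** 2 + x[1] - 2.0, x[0] + x[1] ** 2 - 2.0])

evaluate = make_evaluate(residual)
x = torch.tensor([1.5, 0.5])
f_x, J_x = evaluate(x, 1)  # per the docstring: order=1 returns (f, J)
```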

torchzero/optim/utility/split.py

@@ -11,12 +11,12 @@ class Split(torch.optim.Optimizer):

  Example:

- .. code:: py
-
- opt = Split(
- torch.optim.Adam(model.encoder.parameters(), lr=0.001),
- torch.optim.SGD(model.decoder.parameters(), lr=0.1)
- )
+ ```python
+ opt = Split(
+ torch.optim.Adam(model.encoder.parameters(), lr=0.001),
+ torch.optim.SGD(model.decoder.parameters(), lr=0.1)
+ )
+ ```
  """
  def __init__(self, *optimizers: torch.optim.Optimizer | Iterable[torch.optim.Optimizer]):
  all_params = []
@@ -25,14 +25,14 @@ class Split(torch.optim.Optimizer):
  # gather all params in case user tries to access them from this object
  for i,opt in enumerate(self.optimizers):
  for p in get_params(opt.param_groups, 'all', list):
- if p not in all_params: all_params.append(p)
+ if id(p) not in [id(pr) for pr in all_params]: all_params.append(p)
  else: warnings.warn(
  f'optimizers[{i}] {opt.__class__.__name__} has some duplicate parameters '
  'that are also in previous optimizers. They will be updated multiple times.')

  super().__init__(all_params, {})

- def step(self, closure: Callable | None = None):
+ def step(self, closure: Callable | None = None): # pyright:ignore[reportIncompatibleMethodOverride]
  loss = None

  # if closure provided, populate grad, otherwise each optimizer will call closure separately

torchzero/optim/wrappers/directsearch.py

@@ -7,7 +7,6 @@ import numpy as np
  import torch
  from directsearch.ds import DEFAULT_PARAMS

- from ...modules.second_order.newton import tikhonov_
  from ...utils import Optimizer, TensorList



torchzero/optim/wrappers/fcmaes.py

@@ -2,11 +2,12 @@ from collections.abc import Callable
  from functools import partial
  from typing import Any, Literal

+ import numpy as np
+ import torch
+
  import fcmaes
  import fcmaes.optimizer
  import fcmaes.retry
- import numpy as np
- import torch

  from ...utils import Optimizer, TensorList


torchzero/optim/wrappers/nlopt.py

@@ -75,8 +75,6 @@ class NLOptWrapper(Optimizer):
  so usually you would want to perform a single step, although performing multiple steps will refine the
  solution.

- Some algorithms are buggy with numpy>=2.
-
  Args:
  params: iterable of parameters to optimize or dicts defining parameter groups.
  algorithm (int | _ALGOS_LITERAL): optimization algorithm from https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/

torchzero/optim/wrappers/optuna.py

@@ -6,7 +6,7 @@ import torch

  import optuna

- from ...utils import Optimizer
+ from ...utils import Optimizer, totensor, tofloat

  def silence_optuna():
  optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -65,6 +65,6 @@ class OptunaSampler(Optimizer):
  params.from_vec_(vec)

  loss = closure()
- with torch.enable_grad(): self.study.tell(trial, loss)
+ with torch.enable_grad(): self.study.tell(trial, tofloat(torch.nan_to_num(totensor(loss), 1e32)))

  return loss

torchzero/optim/wrappers/scipy.py

@@ -4,12 +4,17 @@ from functools import partial
  from typing import Any, Literal

  import numpy as np
- import scipy.optimize
  import torch

+ import scipy.optimize
+
  from ...utils import Optimizer, TensorList
- from ...utils.derivatives import jacobian_and_hessian_mat_wrt, jacobian_wrt
- from ...modules.second_order.newton import tikhonov_
+ from ...utils.derivatives import (
+ flatten_jacobian,
+ jacobian_and_hessian_mat_wrt,
+ jacobian_wrt,
+ )
+

  def _ensure_float(x) -> float:
  if isinstance(x, torch.Tensor): return x.detach().cpu().item()
@@ -21,14 +26,6 @@ def _ensure_numpy(x):
  if isinstance(x, np.ndarray): return x
  return np.array(x)

- def matrix_clamp(H: torch.Tensor, reg: float):
- try:
- eigvals, eigvecs = torch.linalg.eigh(H) # pylint:disable=not-callable
- eigvals.clamp_(min=reg)
- return eigvecs @ torch.diag(eigvals) @ eigvecs.mH
- except Exception:
- return H
-
  Closure = Callable[[bool], Any]

  class ScipyMinimize(Optimizer):
@@ -76,8 +73,6 @@ class ScipyMinimize(Optimizer):
  options = None,
  jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
  hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
- tikhonov: float | None = 0,
- min_eigval: float | None = None,
  ):
  defaults = dict(lb=lb, ub=ub)
  super().__init__(params, defaults)
@@ -85,12 +80,10 @@
  self.constraints = constraints
  self.tol = tol
  self.callback = callback
- self.min_eigval = min_eigval
  self.options = options

  self.jac = jac
  self.hess = hess
- self.tikhonov: float | None = tikhonov

  self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
  'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
@@ -111,9 +104,7 @@
  with torch.enable_grad():
  value = closure(False)
  _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
- if self.tikhonov is not None: H = tikhonov_(H, self.tikhonov)
- if self.min_eigval is not None: H = matrix_clamp(H, self.min_eigval)
- return H.detach().cpu().numpy()
+ return H.numpy(force=True)

  def _objective(self, x: np.ndarray, params: TensorList, closure):
  # set params to x
@@ -122,7 +113,10 @@
  # return value and maybe gradients
  if self.use_jac_autograd:
  with torch.enable_grad(): value = _ensure_float(closure())
- return value, params.ensure_grad_().grad.to_vec().detach().cpu().numpy()
+ grad = params.ensure_grad_().grad.to_vec().numpy(force=True)
+ # slsqp requires float64
+ if self.method.lower() == 'slsqp': grad = grad.astype(np.float64)
+ return value, grad
  return _ensure_float(closure(False))

  @torch.no_grad
@@ -135,7 +129,7 @@
  else: hess = None
  else: hess = self.hess

- x0 = params.to_vec().detach().cpu().numpy()
+ x0 = params.to_vec().numpy(force=True)

  # make bounds
  lb, ub = self.group_vals('lb', 'ub', cls=list)
@@ -167,7 +161,7 @@


  class ScipyRootOptimization(Optimizer):
- """Optimization via using scipy.root on gradients, mainly for experimenting!
+ """Optimization via using scipy.optimize.root on gradients, mainly for experimenting!

  Args:
  params: iterable of parameters to optimize or dicts defining parameter groups.
@@ -248,6 +242,72 @@ class ScipyRootOptimization(Optimizer):
  return res.fun


+ class ScipyLeastSquaresOptimization(Optimizer):
+ """Optimization via using scipy.optimize.least_squares on gradients, mainly for experimenting!
+
+ Args:
+ params: iterable of parameters to optimize or dicts defining parameter groups.
+ method (str | None, optional): _description_. Defaults to None.
+ tol (float | None, optional): _description_. Defaults to None.
+ callback (_type_, optional): _description_. Defaults to None.
+ options (_type_, optional): _description_. Defaults to None.
+ jac (T.Literal['2, optional): _description_. Defaults to 'autograd'.
+ """
+ def __init__(
+ self,
+ params,
+ method='trf',
+ jac='autograd',
+ bounds=(-np.inf, np.inf),
+ ftol=1e-8, xtol=1e-8, gtol=1e-8, x_scale=1.0, loss='linear',
+ f_scale=1.0, diff_step=None, tr_solver=None, tr_options=None,
+ jac_sparsity=None, max_nfev=None, verbose=0
+ ):
+ super().__init__(params, {})
+ kwargs = locals().copy()
+ del kwargs['self'], kwargs['params'], kwargs['__class__'], kwargs['jac']
+ self._kwargs = kwargs
+
+ self.jac = jac
+
+
+ def _objective(self, x: np.ndarray, params: TensorList, closure):
+ # set params to x
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+
+ # return the gradients
+ with torch.enable_grad(): self.value = closure()
+ jac = params.ensure_grad_().grad.to_vec()
+ return jac.numpy(force=True)
+
+ def _hess(self, x: np.ndarray, params: TensorList, closure):
+ params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ with torch.enable_grad():
+ value = closure(False)
+ _, H = jacobian_and_hessian_mat_wrt([value], wrt = params)
+ return H.numpy(force=True)
+
+ @torch.no_grad
+ def step(self, closure: Closure): # pylint:disable = signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
+ params = self.get_params()
+
+ x0 = params.to_vec().detach().cpu().numpy()
+
+ if self.jac == 'autograd': jac = partial(self._hess, params = params, closure = closure)
+ else: jac = self.jac
+
+ res = scipy.optimize.least_squares(
+ partial(self._objective, params = params, closure = closure),
+ x0 = x0,
+ jac=jac, # type:ignore
+ **self._kwargs
+ )
+
+ params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ return res.fun
+
+
+

  class ScipyDE(Optimizer):
  """Use scipy.minimize.differential_evolution as pytorch optimizer. Note that this performs full minimization on each step,
@@ -510,4 +570,3 @@ class ScipyBrute(Optimizer):
  **self._kwargs
  )
  params.from_vec_(torch.from_numpy(x0).to(device = params[0].device, dtype=params[0].dtype, copy=False))
- return None
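
To put the `scipy.py` changes in context, a minimal usage sketch of the reworked `ScipyMinimize` wrapper. The `closure(backward)` convention and the `method`/`jac` arguments are taken from the hunks above; the import path, model, data and loss function are assumptions for illustration only:

```python
import torch
from torchzero.optim.wrappers.scipy import ScipyMinimize  # path assumed from the file list

model = torch.nn.Linear(4, 1)
x, y = torch.randn(32, 4), torch.randn(32, 1)

opt = ScipyMinimize(model.parameters(), method="l-bfgs-b", jac="autograd")

def closure(backward: bool = True):
    loss = torch.nn.functional.mse_loss(model(x), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

# a single step runs scipy.optimize.minimize over the flattened parameters
# (full minimization per step, as with the other wrappers above)
opt.step(closure)
```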