torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
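Two renames recur throughout the hunks excerpted below: the top-level entry point is now tz.Optimizer rather than tz.Modular, and the linear-algebra helpers moved from torchzero.utils.linalg to torchzero.linalg. A minimal construction sketch mirroring the updated docstring examples (model is a placeholder; the 0.3.x form is commented out only for contrast):

```python
import torch.nn as nn
import torchzero as tz

model = nn.Linear(4, 1)  # placeholder model

# 0.3.15 style, removed throughout the docstrings below:
# opt = tz.Modular(model.parameters(), tz.m.BFGS(), tz.m.Backtracking())

# 0.4.1 style, as used in the updated docstrings:
opt = tz.Optimizer(
    model.parameters(),
    tz.m.BFGS(),
    tz.m.Backtracking(),
)
```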
@@ -5,10 +5,10 @@ from typing import Any, Literal

  import torch

- from ...core import Chainable, Module, TensorwiseTransform, Transform
+ from ...core import Chainable, Module, TensorTransform, Transform
  from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
- from ...utils.linalg import linear_operator
- from ..functional import initial_step_size, safe_clip
+ from ...linalg import linear_operator
+ from ..opt_utils import initial_step_size, safe_clip
@@ -17,7 +17,7 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
  elif state[key].shape != value.shape: state[key] = value
  else: state[key].lerp_(value, 1-beta)

- class HessianUpdateStrategy(TensorwiseTransform, ABC):
+ class HessianUpdateStrategy(TensorTransform, ABC):
  """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.

  This is an abstract class, to use it, subclass it and override ``update_H`` and/or ``update_B``,
@@ -106,11 +106,12 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  scale_first: bool = False,
  concat_params: bool = True,
  inverse: bool = True,
+ uses_loss: bool = False,
  inner: Chainable | None = None,
  ):
  if defaults is None: defaults = {}
  safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_restart=ptol_restart, gtol=gtol, inverse=inverse, beta=beta, restart_interval=restart_interval, scale_first=scale_first))
- super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
+ super().__init__(defaults, uses_loss=uses_loss, concat_params=concat_params, update_freq=update_freq, inner=inner)

  def reset_for_online(self):
  super().reset_for_online()
@@ -141,23 +142,27 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  return H

  # ------------------------------ common methods ------------------------------ #
- def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
+ def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float | None:
  """returns multiplier to B on 2nd step if ``init_scale='auto'``. H should be divided by this!"""
  ys = y.dot(s)
  yy = y.dot(y)
- if ys != 0 and yy != 0: return yy/ys
- return 1
+ tiny = torch.finfo(ys.dtype).tiny * 2
+ if ys > tiny and yy > tiny: return yy/ys
+ return None

- def reset_P(self, P: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
+ def reset_P(self, P: torch.Tensor, s:torch.Tensor, y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
  """resets ``P`` which is either B or H"""
  set_storage_(P, self.initialize_P(s.numel(), device=P.device, dtype=P.dtype, is_inverse=inverse))
- if init_scale == 'auto': init_scale = self.auto_initial_scale(s,y)
- if init_scale >= 1:
+ if init_scale == 'auto':
+ init_scale = self.auto_initial_scale(s,y)
+ state["scaled"] = init_scale is not None
+
+ if init_scale is not None and init_scale != 1:
  if inverse: P /= init_scale
  else: P *= init_scale

  @torch.no_grad
- def update_tensor(self, tensor, param, grad, loss, state, setting):
+ def single_tensor_update(self, tensor, param, grad, loss, state, setting):
  p = param.view(-1); g = tensor.view(-1)
  inverse = setting['inverse']
  M_key = 'H' if inverse else 'B'
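For reference, the reworked auto_initial_scale above returns the usual y.y / y.s factor (B is multiplied by it, H is divided by it) and now returns None rather than 1 when the dot products are not safely positive, so callers can skip scaling. A standalone sketch with made-up s and y values:

```python
import torch

# standalone illustration of the 'auto' initial scale (values are made up)
s = torch.tensor([0.10, -0.20])   # parameter difference s = x_k - x_{k-1}
y = torch.tensor([0.30, -0.50])   # gradient difference  y = g_k - g_{k-1}

ys, yy = y.dot(s), y.dot(y)       # 0.13 and 0.34
tiny = torch.finfo(ys.dtype).tiny * 2
scale = yy / ys if (ys > tiny and yy > tiny) else None
print(scale)                      # tensor(2.6154): multiply B by it, divide H by it
```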
@@ -182,6 +187,7 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  state['f_prev'] = loss
  state['p_prev'] = p.clone()
  state['g_prev'] = g.clone()
+ state["scaled"] = False
  return

  state['f'] = loss
@@ -205,9 +211,13 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  if gtol is not None and y.abs().max() <= gtol:
  return

- if step == 2 and init_scale == 'auto':
- if inverse: M /= self.auto_initial_scale(s,y)
- else: M *= self.auto_initial_scale(s,y)
+ # apply automatic initial scale if it hasn't been applied
+ if (not state["scaled"]) and (init_scale == 'auto'):
+ scale = self.auto_initial_scale(s,y)
+ if scale is not None:
+ state["scaled"] = True
+ if inverse: M /= self.auto_initial_scale(s,y)
+ else: M *= self.auto_initial_scale(s,y)

  beta = setting['beta']
  if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
@@ -223,7 +233,7 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  state['f_prev'] = loss

  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
  step = state['step']

  if setting['scale_first'] and step == 1:
@@ -250,8 +260,8 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  self.global_state.clear()
  return tensor.mul_(initial_step_size(tensor))

- def get_H(self, var):
- param = var.params[0]
+ def get_H(self, objective):
+ param = objective.params[0]
  state = self.state[param]
  settings = self.settings[param]
  if "B" in state:
@@ -367,22 +377,21 @@ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  B += term1.sub_(term2)
  return B

- def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+
+ def bfgs_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
  sy = s.dot(y)
  if sy <= tol: return H

- sy_sq = safe_clip(sy**2)
-
- Hy = H@y
- scale1 = (sy + y.dot(Hy)) / sy_sq
- term1 = s.outer(s).mul_(scale1)
+ rho = 1.0 / sy
+ Hy = H @ y

- num2 = (Hy.outer(s)).add_(s.outer(y @ H))
- term2 = num2.div_(sy)
+ term1 = (s.outer(s)).mul_(rho * (1 + rho * y.dot(Hy)))
+ term2 = (Hy.outer(s) + s.outer(Hy)).mul_(rho)

- H += term1.sub_(term2)
+ H.add_(term1).sub_(term2)
  return H

+
  class BFGS(_InverseHessianUpdateStrategyDefaults):
  """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
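The rewritten bfgs_H_ above is the standard BFGS inverse-Hessian update in expanded form. A standalone check on random data (not torchzero internals) that the expanded expression equals the factored textbook form (I - rho*s*y^T) H (I - rho*y*s^T) + rho*s*s^T with rho = 1/(s^T y):

```python
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n)
H = A @ A.T + torch.eye(n)      # previous inverse-Hessian approximation (SPD)
s, y = torch.randn(n), torch.randn(n)
if s.dot(y) < 0: y = -y         # enforce the curvature condition s.y > 0

rho = 1.0 / s.dot(y)
Hy = H @ y

# expanded form, as in the new bfgs_H_
H_new = H + s.outer(s) * (rho * (1 + rho * y.dot(Hy))) - (Hy.outer(s) + s.outer(Hy)) * rho

# factored textbook form
I = torch.eye(n)
V = I - rho * torch.outer(s, y)
H_ref = V @ H @ V.T + rho * torch.outer(s, s)

print(torch.allclose(H_new, H_ref, atol=1e-6))  # True
```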
@@ -428,7 +437,7 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):
  BFGS with backtracking line search:

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.BFGS(),
  tz.m.Backtracking()
@@ -437,7 +446,7 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):

  BFGS with trust region
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
  )
@@ -505,7 +514,7 @@ class SR1(_InverseHessianUpdateStrategyDefaults):

  SR1 with trust region
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
  )
@@ -1005,7 +1014,7 @@ def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
  return g - (y * (s.dot(g) / sy))


- class GradientCorrection(Transform):
+ class GradientCorrection(TensorTransform):
  """
  Estimates gradient at minima along search direction assuming function is quadratic.
@@ -1015,7 +1024,7 @@ class GradientCorrection(Transform):
  L-BFGS with gradient correction

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.LBFGS(inner=tz.m.GradientCorrection()),
  tz.m.Backtracking()
@@ -1027,9 +1036,9 @@ class GradientCorrection(Transform):

  """
  def __init__(self):
- super().__init__(None, uses_grad=False)
+ super().__init__()

- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  if 'p_prev' not in states[0]:
  p_prev = unpack_states(states, tensors, 'p_prev', init=params)
  g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
@@ -1154,6 +1163,7 @@ class NewSSM(HessianUpdateStrategy):
  scale_first=scale_first,
  concat_params=concat_params,
  inverse=True,
+ uses_loss=True,
  inner=inner,
  )
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
@@ -1171,13 +1181,18 @@ class NewSSM(HessianUpdateStrategy):

  # this is supposed to be equivalent (and it is)
  def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
- p = H@y
- #(1-y)^2 (ppT)/(pTq)
- #term = p.outer(p).div_(p.dot(y).clip(min=1e-32))
- term = p.outer(p).div_(safe_clip(p.dot(y)))
- H.sub_(term, alpha=1-alpha**2)
+ Hy = H @ y
+ yHy = safe_clip(y.dot(Hy))
+ term = Hy.outer(Hy).div_(yHy)
+ H.sub_(term, alpha=(1-alpha**2))
  return H

+ # def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
+ # Hy = H @ y
+ # yHy = safe_clip(y.dot(Hy))
+ # H -= (Hy.outer(y) @ H).div_(yHy)
+ # return H
+
  class ShorR(HessianUpdateStrategy):
  """Shor’s r-algorithm.
@@ -1,29 +1,39 @@
  import torch

- from ...core import Module, Chainable, apply_transform
- from ...utils import TensorList, vec_to_tensors
- from ..second_order.newton import _newton_step, _get_H
+ from ...core import Chainable, Transform
+ from ...utils import TensorList, unpack_dicts, unpack_states, vec_to_tensors_
+ from ...linalg.linear_operator import Dense
+

  def sg2_(
  delta_g: torch.Tensor,
  cd: torch.Tensor,
  ) -> torch.Tensor:
- """cd is c * perturbation, and must be multiplied by two if hessian estimate is two-sided
- (or divide delta_g by two)."""
+ """cd is c * perturbation."""

- M = torch.outer(1.0 / cd, delta_g)
+ M = torch.outer(0.5 / cd, delta_g)
  H_hat = 0.5 * (M + M.T)

  return H_hat



- class SG2(Module):
+ class SG2(Transform):
  """second-order stochastic gradient

+ 2SPSA (second-order SPSA)
+ ```python
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.SPSA(),
+ tz.m.SG2(),
+ tz.m.LR(1e-2),
+ )
+ ```
+
  SG2 with line search
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.SG2(),
  tz.m.Backtracking()
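The sg2_ helper above forms the symmetrized simultaneous-perturbation Hessian estimate H_hat = 0.5*(M + M.T) with M = outer(0.5 / cd, delta_g), where delta_g is a two-sided gradient difference. A standalone illustration on a quadratic (sizes and sample count are arbitrary) showing that averaging these one-sample estimates converges to the true Hessian:

```python
import torch

# standalone illustration (not torchzero code): on f(x) = 0.5 * x @ A @ x the exact
# gradient is A @ x, so g(x + cd) - g(x - cd) equals 2c * A @ d, and the symmetrized
# estimate below is unbiased for A.
torch.manual_seed(0)
n, c, n_samples = 4, 1e-2, 5000
A = torch.randn(n, n); A = 0.5 * (A + A.T)    # symmetric true Hessian
x = torch.randn(n)

H_est = torch.zeros(n, n)
for _ in range(n_samples):
    d = torch.randint(0, 2, (n,)).float() * 2 - 1   # Rademacher perturbation
    cd = c * d
    delta_g = A @ (x + cd) - A @ (x - cd)
    M = torch.outer(0.5 / cd, delta_g)
    H_est += 0.5 * (M + M.T)
H_est /= n_samples

print((H_est - A).abs().max())   # shrinks toward 0 as n_samples grows
```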
@@ -32,9 +42,9 @@ class SG2(Module):

  SG2 with trust region
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
- tz.m.LevenbergMarquardt(tz.m.SG2()),
+ tz.m.LevenbergMarquardt(tz.m.SG2(beta=0.75, n_samples=4)),
  )
  ```
@@ -43,61 +53,51 @@ class SG2(Module):
  def __init__(
  self,
  n_samples: int = 1,
- h: float = 1e-2,
+ n_first_step_samples: int = 10,
+ start_step: int = 10,
  beta: float | None = None,
- damping: float = 0,
- eigval_fn=None,
- one_sided: bool = False, # one-sided hessian
- use_lstsq: bool = True,
+ damping: float = 1e-4,
+ h: float = 1e-2,
  seed=None,
+ update_freq: int = 1,
  inner: Chainable | None = None,
  ):
- defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, one_sided=one_sided, seed=seed, use_lstsq=use_lstsq)
- super().__init__(defaults)
-
- if inner is not None: self.set_child('inner', inner)
+ defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, seed=seed, start_step=start_step, n_first_step_samples=n_first_step_samples)
+ super().__init__(defaults, update_freq=update_freq, inner=inner)

  @torch.no_grad
- def update(self, var):
- k = self.global_state.get('step', 0) + 1
- self.global_state["step"] = k
+ def update_states(self, objective, states, settings):
+ fs = settings[0]
+ k = self.increment_counter("step", 0)

- params = TensorList(var.params)
- closure = var.closure
+ params = TensorList(objective.params)
+ closure = objective.closure
  if closure is None:
  raise RuntimeError("closure is required for SG2")
  generator = self.get_generator(params[0].device, self.defaults["seed"])

- h = self.get_settings(params, "h")
+ h = unpack_dicts(settings, "h")
  x_0 = params.clone()
- n_samples = self.defaults["n_samples"]
+ n_samples = fs["n_samples"]
+ if k == 0: n_samples = fs["n_first_step_samples"]
  H_hat = None

+ # compute new approximation
  for i in range(n_samples):
  # generate perturbation
  cd = params.rademacher_like(generator=generator).mul_(h)

- # one sided
- if self.defaults["one_sided"]:
- g_0 = TensorList(var.get_grad())
- params.add_(cd)
- closure()
+ # two sided hessian approximation
+ params.add_(cd)
+ closure()
+ g_p = params.grad.fill_none_(params)

- g_p = params.grad.fill_none_(params)
- delta_g = (g_p - g_0) * 2
+ params.copy_(x_0)
+ params.sub_(cd)
+ closure()
+ g_n = params.grad.fill_none_(params)

- # two sided
- else:
- params.add_(cd)
- closure()
- g_p = params.grad.fill_none_(params)
-
- params.copy_(x_0)
- params.sub_(cd)
- closure()
- g_n = params.grad.fill_none_(params)
-
- delta_g = g_p - g_n
+ delta_g = g_p - g_n

  # restore params
  params.set_(x_0)
@@ -114,179 +114,43 @@ class SG2(Module):
  assert H_hat is not None
  if n_samples > 1: H_hat /= n_samples

+ # add damping
+ if fs["damping"] != 0:
+ reg = torch.eye(H_hat.size(0), device=H_hat.device, dtype=H_hat.dtype).mul_(fs["damping"])
+ H_hat += reg
+
  # update H
  H = self.global_state.get("H", None)
  if H is None: H = H_hat
  else:
- beta = self.defaults["beta"]
- if beta is None: beta = k / (k+1)
+ beta = fs["beta"]
+ if beta is None: beta = (k+1) / (k+2)
  H.lerp_(H_hat, 1-beta)

  self.global_state["H"] = H


  @torch.no_grad
- def apply(self, var):
- dir = _newton_step(
- var=var,
- H = self.global_state["H"],
- damping = self.defaults["damping"],
- inner = self.children.get("inner", None),
- H_tfm=None,
- eigval_fn=self.defaults["eigval_fn"],
- use_lstsq=self.defaults["use_lstsq"],
- g_proj=None,
- )
-
- var.update = vec_to_tensors(dir, var.params)
- return var
-
- def get_H(self,var=...):
- return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
-
-
-
-
- # two sided
- # we have g via x + d, x - d
- # H via g(x + d), g(x - d)
- # 1 is x, x+2d
- # 2 is x, x-2d
- # 5 evals in total
-
- # one sided
- # g via x, x + d
- # 1 is x, x + d
- # 2 is x, x - d
- # 3 evals and can use two sided for g_0
-
- class SPSA2(Module):
- """second-order SPSA
-
- SPSA2 with line search
- ```python
- opt = tz.Modular(
- model.parameters(),
- tz.m.SPSA2(),
- tz.m.Backtracking()
- )
- ```
-
- SPSA2 with trust region
- ```python
- opt = tz.Modular(
- model.parameters(),
- tz.m.LevenbergMarquardt(tz.m.SPSA2()),
- )
- ```
- """
-
- def __init__(
- self,
- n_samples: int = 1,
- h: float = 1e-2,
- beta: float | None = None,
- damping: float = 0,
- eigval_fn=None,
- use_lstsq: bool = True,
- seed=None,
- inner: Chainable | None = None,
- ):
- defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
- super().__init__(defaults)
-
- if inner is not None: self.set_child('inner', inner)
-
- @torch.no_grad
- def update(self, var):
- k = self.global_state.get('step', 0) + 1
- self.global_state["step"] = k
+ def apply_states(self, objective, states, settings):
+ fs = settings[0]
+ updates = objective.get_updates()

- params = TensorList(var.params)
- closure = var.closure
- if closure is None:
- raise RuntimeError("closure is required for SPSA2")
+ H: torch.Tensor = self.global_state["H"]
+ k = self.global_state["step"]
+ if k < fs["start_step"]:
+ # don't precondition yet
+ # I guess we can try using trace to scale the update
+ # because it will have horrible scaling otherwise
+ torch._foreach_div_(updates, H.trace())
+ return objective

- generator = self.get_generator(params[0].device, self.defaults["seed"])
+ b = torch.cat([t.ravel() for t in updates])
+ sol = torch.linalg.lstsq(H, b).solution # pylint:disable=not-callable

- h = self.get_settings(params, "h")
- x_0 = params.clone()
- n_samples = self.defaults["n_samples"]
- H_hat = None
- g_0 = None
-
- for i in range(n_samples):
- # perturbations for g and H
- cd_g = params.rademacher_like(generator=generator).mul_(h)
- cd_H = params.rademacher_like(generator=generator).mul_(h)
-
- # evaluate 4 points
- x_p = x_0 + cd_g
- x_n = x_0 - cd_g
+ vec_to_tensors_(sol, updates)
+ return objective

- params.set_(x_p)
- f_p = closure(False)
- params.add_(cd_H)
- f_pp = closure(False)
+ def get_H(self, objective=...):
+ return Dense(self.global_state["H"])

- params.set_(x_n)
- f_n = closure(False)
- params.add_(cd_H)
- f_np = closure(False)
-
- g_p_vec = (f_pp - f_p) / cd_H
- g_n_vec = (f_np - f_n) / cd_H
- delta_g = g_p_vec - g_n_vec
-
- # restore params
- params.set_(x_0)

- # compute grad
- g_i = (f_p - f_n) / (2 * cd_g)
- if g_0 is None: g_0 = g_i
- else: g_0 += g_i
-
- # compute H hat
- H_i = sg2_(
- delta_g = delta_g.to_vec().div_(2.0),
- cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
- )
- if H_hat is None: H_hat = H_i
- else: H_hat += H_i
-
- assert g_0 is not None and H_hat is not None
- if n_samples > 1:
- g_0 /= n_samples
- H_hat /= n_samples
-
- # set grad to approximated grad
- var.grad = g_0
-
- # update H
- H = self.global_state.get("H", None)
- if H is None: H = H_hat
- else:
- beta = self.defaults["beta"]
- if beta is None: beta = k / (k+1)
- H.lerp_(H_hat, 1-beta)
-
- self.global_state["H"] = H
-
- @torch.no_grad
- def apply(self, var):
- dir = _newton_step(
- var=var,
- H = self.global_state["H"],
- damping = self.defaults["damping"],
- inner = self.children.get("inner", None),
- H_tfm=None,
- eigval_fn=self.defaults["eigval_fn"],
- use_lstsq=self.defaults["use_lstsq"],
- g_proj=None,
- )
-
- var.update = vec_to_tensors(dir, var.params)
- return var
-
- def get_H(self,var=...):
- return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
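The replacement apply_states shown earlier in this hunk flattens the incoming update, solves H d = g with torch.linalg.lstsq, and writes the solution back into the per-parameter tensors (falling back to trace scaling before start_step). A minimal standalone sketch of that solve on a two-parameter toy problem (the explicit scatter loop stands in for vec_to_tensors_):

```python
import torch

# toy problem: f(x, y) = x^2 + 2*y^2, gradient (2x, 4y), Hessian diag(2, 4)
H = torch.diag(torch.tensor([2.0, 4.0]))
updates = [torch.tensor([2.0]), torch.tensor([4.0])]   # gradient at (1, 1), one tensor per parameter

b = torch.cat([t.ravel() for t in updates])
d = torch.linalg.lstsq(H, b).solution                  # Newton direction, here [1., 1.]

offset = 0
for t in updates:                                      # write the solution back per parameter
    n = t.numel()
    t.copy_(d[offset:offset + n].view_as(t))
    offset += n

print(updates)                                         # [tensor([1.]), tensor([1.])]
```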