torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/quasi_newton.py
@@ -1,11 +1,13 @@
  """Use BFGS or maybe SR1."""
- from typing import Any, Literal
  from abc import ABC, abstractmethod
  from collections.abc import Mapping
+ from typing import Any, Literal
+
  import torch

- from ...core import Chainable, Module, Preconditioner, TensorwisePreconditioner
- from ...utils import TensorList, set_storage_
+ from ...core import Chainable, Module, TensorwiseTransform, Transform
+ from ...utils import TensorList, set_storage_, unpack_states
+

  def _safe_dict_update_(d1_:dict, d2:dict):
      inter = set(d1_.keys()).intersection(d2.keys())
@@ -17,14 +19,14 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
      elif state[key].shape != value.shape: state[key] = value
      else: state[key].lerp_(value, 1-beta)

- class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
+ class HessianUpdateStrategy(TensorwiseTransform, ABC):
      def __init__(
          self,
          defaults: dict | None = None,
          init_scale: float | Literal["auto"] = "auto",
          tol: float = 1e-10,
          tol_reset: bool = True,
-         reset_interval: int | None = None,
+         reset_interval: int | None | Literal['auto'] = None,
          beta: float | None = None,
          update_freq: int = 1,
          scale_first: bool = True,
@@ -44,7 +46,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
          if ys != 0 and yy != 0: return yy/ys
          return 1

-     def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor,inverse:bool, init_scale: Any):
+     def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
          set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
          if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
          if init_scale >= 1:
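For reference, the yy/ys value returned by the "auto" branch above is the usual initial-scaling heuristic for quasi-Newton methods (see Nocedal & Wright, Numerical Optimization, ch. 6): before any curvature pairs have been accumulated, the Hessian approximation is initialized as

    B_0 = \frac{y^\top y}{y^\top s} I, \qquad H_0 = B_0^{-1} = \frac{y^\top s}{y^\top y} I,

so that the very first quasi-Newton step has roughly the right length. How this factor is applied in the inverse case is handled inside _reset_M_ and is not shown in this hunk.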
@@ -62,7 +64,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
          raise NotImplementedError

      @torch.no_grad
-     def update_tensor(self, tensor, param, grad, state, settings):
+     def update_tensor(self, tensor, param, grad, loss, state, settings):
          p = param.view(-1); g = tensor.view(-1)
          inverse = settings['inverse']
          M_key = 'H' if inverse else 'B'
@@ -73,6 +75,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
          tol = settings['tol']
          tol_reset = settings['tol_reset']
          reset_interval = settings['reset_interval']
+         if reset_interval == 'auto': reset_interval = tensor.numel() + 1

          if M is None:
              M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
@@ -81,10 +84,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
              else: M *= init_scale

              state[M_key] = M
+             state['f_prev'] = loss
              state['p_prev'] = p.clone()
              state['g_prev'] = g.clone()
              return

+         state['f'] = loss
          p_prev = state['p_prev']
          g_prev = state['g_prev']
          s: torch.Tensor = p - p_prev
@@ -93,13 +98,13 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
          state['g_prev'].copy_(g)

          if reset_interval is not None and step != 0 and step % reset_interval == 0:
-             self._reset_M_(M, s, y, inverse, init_scale)
+             self._reset_M_(M, s, y, inverse, init_scale, state)
              return

          # tolerance on gradient difference to avoid exploding after converging
-         elif y.abs().max() <= tol:
+         if y.abs().max() <= tol:
              # reset history
-             if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
+             if tol_reset: self._reset_M_(M, s, y, inverse, init_scale, state)
              return

          if step == 1 and init_scale == 'auto':
@@ -117,12 +122,16 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
              B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
              _maybe_lerp_(state, 'B', B_new, beta)

+         state['f_prev'] = loss
+
      @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, state, settings):
+     def apply_tensor(self, tensor, param, grad, loss, state, settings):
          step = state.get('step', 0)

          if settings['scale_second'] and step == 2:
-             tensor = tensor / tensor.abs().mean().clip(min=1)
+             scale_factor = 1 / tensor.abs().sum().clip(min=1)
+             scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
+             tensor = tensor * scale_factor

          inverse = settings['inverse']
          if inverse:
@@ -196,19 +205,15 @@ class SR1(HUpdateStrategy):
      def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
          return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])

- # BFGS has defaults - init_scale = "auto" and scale_second = False
- # SR1 has defaults - init_scale = 1 and scale_second = True
- # basically some methods work better with first and some with second.
- # I inherit from BFGS or SR1 to avoid writing all those arguments again
  # ------------------------------------ DFP ----------------------------------- #
  def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
      sy = torch.dot(s, y)
      if sy.abs() <= tol: return H
      term1 = torch.outer(s, s).div_(sy)
-     denom = torch.dot(y, H @ y) #
-     if denom.abs() <= tol: return H
+     yHy = torch.dot(y, H @ y) #
+     if yHy.abs() <= tol: return H
      num = H @ torch.outer(y, y) @ H
-     term2 = num.div_(denom)
+     term2 = num.div_(yHy)
      H += term1.sub_(term2)
      return H
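In conventional notation, dfp_H_ above is the classical Davidon–Fletcher–Powell update of the inverse Hessian approximation, with both denominators guarded by tol:

    H_{k+1} = H_k + \frac{s s^\top}{s^\top y} - \frac{H_k y\, y^\top H_k}{y^\top H_k y}.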
@@ -223,34 +228,35 @@ class DFP(HUpdateStrategy):

  def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
      c = H.T @ s
-     denom = c.dot(y)
-     if denom.abs() <= tol: return H
+     cy = c.dot(y)
+     if cy.abs() <= tol: return H
      num = (H@y).sub_(s).outer(c)
-     H -= num/denom
+     H -= num/cy
      return H

  def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
      c = y
-     denom = c.dot(y)
-     if denom.abs() <= tol: return H
+     cy = c.dot(y)
+     if cy.abs() <= tol: return H
      num = (H@y).sub_(s).outer(c)
-     H -= num/denom
+     H -= num/cy
      return H

  def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor, tol: float):
      c = g_prev
-     denom = c.dot(y)
-     if denom.abs() <= tol: return H
+     cy = c.dot(y)
+     if cy.abs() <= tol: return H
      num = (H@y).sub_(s).outer(c)
-     H -= num/denom
+     H -= num/cy
      return H

  def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-     c = torch.linalg.multi_dot([H,H,y]) # pylint:disable=not-callable
-     denom = c.dot(y)
-     if denom.abs() <= tol: return H
-     num = (H@y).sub_(s).outer(c)
-     H -= num/denom
+     Hy = H @ y
+     c = H @ Hy # pylint:disable=not-callable
+     cy = c.dot(y)
+     if cy.abs() <= tol: return H
+     num = Hy.sub_(s).outer(c)
+     H -= num/cy
      return H

  class BroydenGood(HUpdateStrategy):
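broyden_good_H_, broyden_bad_H_, greenstadt1_H_ and greenstadt2_H_ above are all the same rank-one correction to the inverse approximation, differing only in the choice of c (H^\top s, y, g_{prev} and H H y respectively):

    H_{k+1} = H_k - \frac{(H_k y - s)\, c^\top}{c^\top y}.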
@@ -271,11 +277,7 @@ class Greenstadt2(HUpdateStrategy):


  def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float):
-     n = H.shape[0]
-
      j = y.abs().argmax()
-     u = torch.zeros(n, device=H.device, dtype=H.dtype)
-     u[j] = 1.0

      denom = y[j]
      if denom.abs() < tol: return H
@@ -295,15 +297,15 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor,
      s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
      I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
      d = (R + I * (s_norm/2)) @ s
-     denom = d.dot(s)
-     if denom.abs() <= tol: return H, R
-     R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(denom)))
+     ds = d.dot(s)
+     if ds.abs() <= tol: return H, R
+     R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))

      c = H.T @ d
-     denom = c.dot(y)
-     if denom.abs() <= tol: return H, R
+     cy = c.dot(y)
+     if cy.abs() <= tol: return H, R
      num = (H@y).sub_(s).outer(c)
-     H -= num/denom
+     H -= num/cy
      return H, R

  class ThomasOptimalMethod(HUpdateStrategy):
@@ -313,6 +315,11 @@ class ThomasOptimalMethod(HUpdateStrategy):
          H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
          return H

+     def _reset_M_(self, M, s, y,inverse, init_scale, state):
+         super()._reset_M_(M, s, y, inverse, init_scale, state)
+         for st in self.state.values():
+             st.pop("R", None)
+
  # ------------------------ powell's symmetric broyden ------------------------ #
  def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
      y_Bs = y - B@s
@@ -324,6 +331,7 @@ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
      B += term1.sub_(term2)
      return B

+ # I couldn't find formula for H
  class PSB(HessianUpdateStrategy):
      def __init__(
          self,
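psb_B_ is only partly visible in this hunk. For reference, the Powell symmetric Broyden update it implements acts on the Hessian approximation B rather than its inverse (hence the new comment about the missing H formula) and is usually written

    B_{k+1} = B_k + \frac{(y - B_k s) s^\top + s (y - B_k s)^\top}{s^\top s} - \frac{s^\top (y - B_k s)}{(s^\top s)^2}\, s s^\top.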
@@ -356,17 +364,85 @@ class PSB(HessianUpdateStrategy):
      def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
          return psb_B_(B=B, s=s, y=y, tol=settings['tol'])

- def pearson2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+
+ # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
+ def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     Hy = H@y
+     yHy = y.dot(Hy)
+     if yHy.abs() <= tol: return H
+     num = (s - Hy).outer(Hy)
+     H += num.div_(yHy)
+     return H
+
+ class Pearson(HUpdateStrategy):
+     """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+     This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return pearson_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
      sy = s.dot(y)
      if sy.abs() <= tol: return H
      num = (s - H@y).outer(s)
      H += num.div_(sy)
      return H

- class Pearson2(HUpdateStrategy):
-     """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
+ class McCormick(HUpdateStrategy):
+     """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+     This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
      def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-         return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
+         return mccormick_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+     Hy = H @ y
+     yHy = y.dot(Hy)
+     if yHy.abs() < tol: return H, R
+     H -= Hy.outer(Hy) / yHy
+     R += (s - R@y).outer(Hy) / yHy
+     return H, R
+
+ class ProjectedNewtonRaphson(HessianUpdateStrategy):
+     """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+     Algorithm 7"""
+     def __init__(
+         self,
+         init_scale: float | Literal["auto"] = 'auto',
+         tol: float = 1e-10,
+         tol_reset: bool = True,
+         reset_interval: int | None | Literal['auto'] = 'auto',
+         beta: float | None = None,
+         update_freq: int = 1,
+         scale_first: bool = True,
+         scale_second: bool = False,
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         super().__init__(
+             init_scale=init_scale,
+             tol=tol,
+             tol_reset=tol_reset,
+             reset_interval=reset_interval,
+             beta=beta,
+             update_freq=update_freq,
+             scale_first=scale_first,
+             scale_second=scale_second,
+             concat_params=concat_params,
+             inverse=True,
+             inner=inner,
+         )
+
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+         H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
+         state["R"] = R
+         return H
+
+     def _reset_M_(self, M, s, y, inverse, init_scale, state):
+         assert inverse
+         M.copy_(state["R"])

  # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
  def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
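Written out, the two rank-one updates added above differ only in the outer-product direction: pearson_H_ corrects along H_k y, while mccormick_H_ corrects along s,

    \text{Pearson:}\ \ H_{k+1} = H_k + \frac{(s - H_k y)(H_k y)^\top}{y^\top H_k y}, \qquad \text{McCormick:}\ \ H_{k+1} = H_k + \frac{(s - H_k y)\, s^\top}{s^\top y}.

Both satisfy the secant condition H_{k+1} y = s whenever the guarded denominator is nonzero.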
@@ -471,4 +547,137 @@ class SSVM(HessianUpdateStrategy):
          )

      def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-         return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+         return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+
+ # HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     Hy = H@y
+     ys = y.dot(s)
+     if ys.abs() <= tol: return H
+     yHy = y.dot(Hy)
+     denom = ys + yHy
+     if denom.abs() <= tol: return H
+
+     term1 = 1/denom
+     term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
+     term3 = s.outer(y) @ H
+     term4 = Hy.outer(s)
+     term5 = Hy.outer(y) @ H
+
+     inner_term = term2 - term3 - term4 - term5
+     H += inner_term.mul_(term1)
+     return H
+
+ def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
+     sy = s.dot(y)
+     if sy.abs() < torch.finfo(g[0].dtype).eps: return g
+     return g - (y * (s.dot(g) / sy))
+
+
+ class GradientCorrection(Transform):
+     """estimates gradient at minima along search direction assuming function is quadratic as proposed in HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+
+     This can useful as inner module for second order methods."""
+     def __init__(self):
+         super().__init__(None, uses_grad=False)
+
+     def apply(self, tensors, params, grads, loss, states, settings):
+         if 'p_prev' not in states[0]:
+             p_prev = unpack_states(states, tensors, 'p_prev', init=params)
+             g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
+             return tensors
+
+         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
+         g_hat = gradient_correction(TensorList(tensors), params-p_prev, tensors-g_prev)
+
+         p_prev.copy_(params)
+         g_prev.copy_(tensors)
+         return g_hat
+
+ class Horisho(HUpdateStrategy):
+     """HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394"""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return hoshino_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ # Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+ def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     sy = s.dot(y)
+     if sy.abs() < tol: return H
+     Hy = H @ y
+
+     term1 = (s.outer(y) @ H).div_(sy)
+     term2 = (Hy.outer(s)).div_(sy)
+     term3 = 1 + (y.dot(Hy) / sy)
+     term4 = s.outer(s).div_(sy)
+
+     H -= (term1 + term2 - term4.mul_(term3))
+     return H
+
+ class FletcherVMM(HUpdateStrategy):
+     """Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317"""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return fletcher_vmm_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+
+ # Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+ def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
+     sy = s.dot(y)
+     if sy < tol: return H
+
+     term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy
+
+     if type == 1:
+         pba = (2*sy + 2*(f-f_prev)) / sy
+
+     elif type == 2:
+         pba = (f_prev - f + 1/(2*sy)) / sy
+
+     else:
+         raise RuntimeError(type)
+
+     term3 = 1/pba + y.dot(H@y) / sy
+     term4 = s.outer(s) / sy
+
+     H.sub_(term1)
+     H.add_(term4.mul_(term3))
+     return H
+
+
+ class NewSSM(HessianUpdateStrategy):
+     """Self-scaling method, requires a line search.
+
+     Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U."""
+     def __init__(
+         self,
+         type: Literal[1, 2] = 1,
+         init_scale: float | Literal["auto"] = "auto",
+         tol: float = 1e-10,
+         tol_reset: bool = True,
+         reset_interval: int | None = None,
+         beta: float | None = None,
+         update_freq: int = 1,
+         scale_first: bool = True,
+         scale_second: bool = False,
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         super().__init__(
+             defaults=dict(type=type),
+             init_scale=init_scale,
+             tol=tol,
+             tol_reset=tol_reset,
+             reset_interval=reset_interval,
+             beta=beta,
+             update_freq=update_freq,
+             scale_first=scale_first,
+             scale_second=scale_second,
+             concat_params=concat_params,
+             inverse=True,
+             inner=inner,
+         )
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         f = state['f']
+         f_prev = state['f_prev']
+         return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=settings['type'], tol=settings['tol'])
+
+
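The gradient_correction helper added above is the quadratic-model estimate of the gradient at the minimizer along the previous step: with s = p − p_prev and y = g − g_prev, a quadratic objective satisfies y = Hs, so the directional derivative along s vanishes at the point where the gradient equals

    \hat g = g - \frac{s^\top g}{s^\top y}\, y,

which is exactly what the function returns (falling back to the raw g when s^\top y is near machine epsilon).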
torchzero/modules/second_order/newton.py
@@ -1,22 +1,29 @@
  import warnings
+ from collections.abc import Callable
  from functools import partial
  from typing import Literal
- from collections.abc import Callable
+
  import torch

- from ...core import Chainable, apply, Module
- from ...utils import vec_to_tensors, TensorList
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import TensorList, vec_to_tensors
  from ...utils.derivatives import (
      hessian_list_to_mat,
      hessian_mat,
+     hvp,
+     hvp_fd_central,
+     hvp_fd_forward,
      jacobian_and_hessian_wrt,
  )


  def lu_solve(H: torch.Tensor, g: torch.Tensor):
-     x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
-     if info == 0: return x
-     return None
+     try:
+         x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+         if info == 0: return x
+         return None
+     except RuntimeError:
+         return None

  def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
      x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
@@ -28,10 +35,15 @@ def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
  def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
      return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable

- def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None):
+ def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
      try:
          L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
          if tfm is not None: L = tfm(L)
+         if search_negative and L[0] < 0:
+             d = Q[0]
+             # use eigvec or -eigvec depending on if it points in same direction as gradient
+             return g.dot(d).sign() * d
+
          L.reciprocal_()
          return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
      except torch.linalg.LinAlgError:
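With H = Q\,\mathrm{diag}(\lambda)\,Q^\top and eigenvalues sorted ascending (so L[0] is the smallest), the regular path above applies the optional eigenvalue transform and returns Q\,\mathrm{diag}(1/\tilde\lambda)\,Q^\top g. The new search_negative branch instead bails out as soon as negative curvature is detected and returns what the code treats as an eigenvector for the most negative eigenvalue, sign-matched against the gradient so that the subsequent update still points in a descent-compatible direction.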
@@ -52,6 +64,8 @@ class Newton(Module):
      Args:
          reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
          eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+         search_negative (bool, Optional):
+             if True, whenever a negative eigenvalue is detected, the direction is taken along an eigenvector corresponding to a negative eigenvalue.
          hessian_method (str):
              how to calculate hessian. Defaults to "autograd".
          vectorize (bool, optional):
@@ -71,27 +85,29 @@
          self,
          reg: float = 1e-6,
          eig_reg: bool = False,
+         search_negative: bool = False,
          hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
          vectorize: bool = True,
          inner: Chainable | None = None,
          H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
          eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
      ):
-         defaults = dict(reg=reg, eig_reg=eig_reg, abs=abs,hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm)
+         defaults = dict(reg=reg, eig_reg=eig_reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative)
          super().__init__(defaults)

          if inner is not None:
              self.set_child('inner', inner)

      @torch.no_grad
-     def step(self, vars):
-         params = TensorList(vars.params)
-         closure = vars.closure
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
          if closure is None: raise RuntimeError('NewtonCG requires closure')

          settings = self.settings[params[0]]
          reg = settings['reg']
          eig_reg = settings['eig_reg']
+         search_negative = settings['search_negative']
          hessian_method = settings['hessian_method']
          vectorize = settings['vectorize']
          H_tfm = settings['H_tfm']
@@ -100,16 +116,16 @@ class Newton(Module):
          # ------------------------ calculate grad and hessian ------------------------ #
          if hessian_method == 'autograd':
              with torch.enable_grad():
-                 loss = vars.loss = vars.loss_approx = closure(False)
+                 loss = var.loss = var.loss_approx = closure(False)
                  g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
                  g_list = [t[0] for t in g_list] # remove leading dim from loss
-                 vars.grad = g_list
+                 var.grad = g_list
                  H = hessian_list_to_mat(H_list)

          elif hessian_method in ('func', 'autograd.functional'):
              strat = 'forward-mode' if vectorize else 'reverse-mode'
              with torch.enable_grad():
-                 g_list = vars.get_grad(retain_graph=True)
+                 g_list = var.get_grad(retain_graph=True)
                  H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
                                                method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]

@@ -117,9 +133,10 @@
              raise ValueError(hessian_method)

          # -------------------------------- inner step -------------------------------- #
+         update = var.get_update()
          if 'inner' in self.children:
-             g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
-         g = torch.cat([t.view(-1) for t in g_list])
+             update = apply_transform(self.children['inner'], update, params=params, grads=list(g_list), var=var)
+         g = torch.cat([t.ravel() for t in update])

          # ------------------------------- regulazition ------------------------------- #
          if eig_reg: H = eig_tikhonov_(H, reg)
@@ -129,14 +146,14 @@
          update = None
          if H_tfm is not None:
              H, is_inv = H_tfm(H, g)
-             if is_inv: update = H
+             if is_inv: update = H @ g

-         if eigval_tfm is not None:
-             update = eigh_solve(H, g, eigval_tfm)
+         if search_negative or (eigval_tfm is not None):
+             update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)

          if update is None: update = cholesky_solve(H, g)
          if update is None: update = lu_solve(H, g)
          if update is None: update = least_squares_solve(H, g)

-         vars.update = vec_to_tensors(update, params)
-         return vars
+         var.update = vec_to_tensors(update, params)
+         return var
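Taken together, the solver changes above make the dense Newton solve degrade gracefully: Cholesky is tried first (fast, requires positive definiteness), then LU via solve_ex (now also guarded against RuntimeError), then a least-squares solve that always returns something. A rough standalone sketch of that chain, using a hypothetical newton_direction helper rather than the module's own functions:

    import torch

    def newton_direction(H: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        # 1) Cholesky: cheapest, but only valid when H is symmetric positive definite.
        L, info = torch.linalg.cholesky_ex(H)
        if info == 0:
            return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
        # 2) LU: handles indefinite but non-singular H; some backends raise
        #    instead of reporting failure through `info`, hence the try/except.
        try:
            x, info = torch.linalg.solve_ex(H, g)
            if info == 0:
                return x
        except RuntimeError:
            pass
        # 3) Least squares: a fallback that does not require H to be invertible.
        return torch.linalg.lstsq(H, g)[0]

The returned direction is then consumed like a gradient, i.e. subtracted from the parameters by whatever step rule follows.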
torchzero/modules/second_order/newton_cg.py
@@ -6,14 +6,14 @@ import torch
  from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
  from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward

- from ...core import Chainable, apply, Module
+ from ...core import Chainable, apply_transform, Module
  from ...utils.linalg.solve import cg

  class NewtonCG(Module):
      def __init__(
          self,
          maxiter=None,
-         tol=1e-3,
+         tol=1e-4,
          reg: float = 1e-8,
          hvp_method: Literal["forward", "central", "autograd"] = "forward",
          h=1e-3,
@@ -27,9 +27,9 @@ class NewtonCG(Module):
          self.set_child('inner', inner)

      @torch.no_grad
-     def step(self, vars):
-         params = TensorList(vars.params)
-         closure = vars.closure
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
          if closure is None: raise RuntimeError('NewtonCG requires closure')

          settings = self.settings[params[0]]
@@ -42,7 +42,7 @@ class NewtonCG(Module):

          # ---------------------- Hessian vector product function --------------------- #
          if hvp_method == 'autograd':
-             grad = vars.get_grad(create_graph=True)
+             grad = var.get_grad(create_graph=True)

              def H_mm(x):
                  with torch.enable_grad():
@@ -51,7 +51,7 @@ class NewtonCG(Module):
          else:

              with torch.enable_grad():
-                 grad = vars.get_grad()
+                 grad = var.get_grad()

              if hvp_method == 'forward':
                  def H_mm(x):
@@ -66,19 +66,20 @@ class NewtonCG(Module):


          # -------------------------------- inner step -------------------------------- #
-         b = grad
+         b = var.get_update()
          if 'inner' in self.children:
-             b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+             b = as_tensorlist(apply_transform(self.children['inner'], b, params=params, grads=grad, var=var))

          # ---------------------------------- run cg ---------------------------------- #
          x0 = None
-         if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+         if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
+
          x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
          if warm_start:
              assert x0 is not None
              x0.copy_(x)

-         vars.update = x
-         return vars
+         var.update = x
+         return var

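For context, NewtonCG never builds H explicitly; cg only consumes the Hessian-vector product H_mm. A compact illustration of a warm-started conjugate-gradient solve on a flat tensor (an assumption-level sketch, not torchzero's cg from utils.linalg.solve) looks like:

    import torch
    from collections.abc import Callable

    def cg_solve(A_mm: Callable[[torch.Tensor], torch.Tensor], b: torch.Tensor,
                 x0: torch.Tensor | None = None, tol: float = 1e-4,
                 maxiter: int | None = None, reg: float = 0.0) -> torch.Tensor:
        # Solves (A + reg*I) x = b using only matrix-vector products A_mm(v).
        x = torch.zeros_like(b) if x0 is None else x0.clone()
        r = b - (A_mm(x) + reg * x)   # residual of the initial (possibly warm) guess
        p = r.clone()                 # first search direction
        rs = r.dot(r)
        if maxiter is None: maxiter = b.numel()
        for _ in range(maxiter):
            if rs.sqrt() <= tol: break
            Ap = A_mm(p) + reg * p
            alpha = rs / p.dot(Ap)
            x = x + alpha * p
            r = r - alpha * Ap
            rs_new = r.dot(r)
            p = r + (rs_new / rs) * p
            rs = rs_new
        return x

Warm starting then just means keeping the previous solution around and passing it back in as x0 on the next step, which is what the x0.copy_(x) after the cg call in the hunk above achieves.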