torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/cg.py
@@ -1,14 +1,47 @@
 from abc import ABC, abstractmethod
+from typing import Literal

 import torch

-from ...core import Chainable, Transform, apply
-from ...utils import TensorList, as_tensorlist
+from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
+from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
+from .quasi_newton import _safe_clip, HessianUpdateStrategy


 class ConguateGradientBase(Transform, ABC):
-    """all CGs are the same except beta calculation"""
-    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None = None, inner: Chainable | None = None):
+    """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
+
+    This is an abstract class, to use it, subclass it and override `get_beta`.
+
+
+    Args:
+        defaults (dict | None, optional): dictionary of settings defaults. Defaults to None.
+        clip_beta (bool, optional): whether to clip beta to be no less than 0. Defaults to False.
+        reset_interval (int | None | Literal["auto"], optional):
+            interval between resetting the search direction.
+            "auto" means number of dimensions + 1, None means no reset. Defaults to None.
+        inner (Chainable | None, optional): previous direction is added to the output of this module. Defaults to None.
+
+    Example:
+
+        .. code-block:: python
+
+            class PolakRibiere(ConguateGradientBase):
+                def __init__(
+                    self,
+                    clip_beta=True,
+                    reset_interval: int | None = None,
+                    inner: Chainable | None = None
+                ):
+                    super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+                def get_beta(self, p, g, prev_g, prev_d):
+                    denom = prev_g.dot(prev_g)
+                    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
+                    return g.dot(g - prev_g) / denom
+
+    """
+    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
         if defaults is None: defaults = {}
         defaults['reset_interval'] = reset_interval
         defaults['clip_beta'] = clip_beta
@@ -17,6 +50,15 @@ class ConguateGradientBase(Transform, ABC):
         if inner is not None:
             self.set_child('inner', inner)

+    def reset(self):
+        super().reset()
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_grad')
+        self.global_state.pop('stage', None)
+        self.global_state['step'] = self.global_state.get('step', 1) - 1
+
     def initialize(self, p: TensorList, g: TensorList):
         """runs on first step when prev_grads and prev_dir are not available"""

@@ -25,38 +67,55 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = as_tensorlist(tensors)
         params = as_tensorlist(params)

-        step = self.global_state.get('step', 0)
-        prev_dir, prev_grads = self.get_state('prev_dir', 'prev_grad', params=params, cls=TensorList)
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step

         # initialize on first step
-        if step == 0:
+        if self.global_state.get('stage', 0) == 0:
+            g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+            d_prev.copy_(tensors)
+            g_prev.copy_(tensors)
             self.initialize(params, tensors)
-            prev_dir.copy_(tensors)
-            prev_grads.copy_(tensors)
-            self.global_state['step'] = step + 1
+            self.global_state['stage'] = 1
+
+        else:
+            # if `update_tensors` was called multiple times before `apply_tensors`,
+            # stage becomes 2
+            self.global_state['stage'] = 2
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        step = self.global_state['step']
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+
+        assert self.global_state['stage'] != 0
+        if self.global_state['stage'] == 1:
+            self.global_state['stage'] = 2
             return tensors

+        params = as_tensorlist(params)
+        g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+
         # get beta
-        beta = self.get_beta(params, tensors, prev_grads, prev_dir)
-        if self.settings[params[0]]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
-        prev_grads.copy_(tensors)
+        beta = self.get_beta(params, tensors, g_prev, d_prev)
+        if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]

         # inner step
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply(self.children['inner'], tensors, params, grads, vars))
-
         # calculate new direction with beta
-        dir = tensors.add_(prev_dir.mul_(beta))
-        prev_dir.copy_(dir)
+        dir = tensors.add_(d_prev.mul_(beta))
+        d_prev.copy_(dir)

         # resetting
-        self.global_state['step'] = step + 1
-        reset_interval = self.settings[params[0]]['reset_interval']
-        if reset_interval is not None and (step+1) % reset_interval == 0:
+        reset_interval = settings[0]['reset_interval']
+        if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
+        if reset_interval is not None and step % reset_interval == 0:
             self.reset()

         return dir
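
For reference, the recurrence that `update_tensors`/`apply_tensors` implement above is the classical nonlinear conjugate gradient update: the new direction is the incoming update plus beta times the previous direction, with an optional restart every `reset_interval` steps ("auto" restarts every numel + 1 steps). Below is a minimal self-contained sketch of that same recurrence in plain PyTorch on a toy quadratic, using the Polak-Ribière beta with clipping and an exact line search; it uses no torchzero classes, only the algebra shown in the hunk.

    import torch

    # toy strongly convex quadratic: f(x) = 0.5 * x^T A x - b^T x
    torch.manual_seed(0)
    n = 20
    A = torch.randn(n, n, dtype=torch.float64)
    A = A @ A.T + n * torch.eye(n, dtype=torch.float64)   # positive definite
    b = torch.randn(n, dtype=torch.float64)

    def grad(x):                          # gradient of f
        return A @ x - b

    x = torch.zeros(n, dtype=torch.float64)
    g = grad(x)
    d = g.clone()                         # first direction is the raw update, as on the first step above
    reset_interval = n + 1                # mirrors reset_interval="auto"

    for step in range(1, 200):
        alpha = g.dot(d) / (d @ A @ d)    # exact line search along -d (closed form for a quadratic)
        x = x - alpha * d

        g_new = grad(x)
        if g_new.norm() < 1e-10:
            break

        # Polak-Ribiere beta, clipped at zero as with clip_beta=True
        beta = max(0.0, float(g_new.dot(g_new - g) / g.dot(g)))
        if step % reset_interval == 0:    # periodic restart, as in the resetting branch
            beta = 0.0

        d = g_new + beta * d              # same recurrence as dir = tensors + beta * prev_dir
        g = g_new

    print(float(grad(x).norm()))          # ~0: converged

On a quadratic with an exact line search this reduces to linear CG, which is why the modules below are meant to be followed by a Wolfe line search in the nonlinear case.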
@@ -68,7 +127,11 @@ def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
     return g.dot(g - prev_g) / denom

 class PolakRibiere(ConguateGradientBase):
-    """Polak-Ribière-Polyak nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this."""
+    """Polak-Ribière-Polyak nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, clip_beta=True, reset_interval: int | None = None, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

@@ -81,8 +144,12 @@ def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
     return gg / prev_gg

 class FletcherReeves(ConguateGradientBase):
-    """Fletcher–Reeves nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    """Fletcher–Reeves nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
+    def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def initialize(self, p, g):
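
The note that keeps recurring in these docstrings (follow the CG direction with a Wolfe line search) corresponds to chaining two modules. A rough usage sketch follows; it assumes the `tz.Modular` wrapper, the `tz.m` module namespace, and the closure-with-`backward` convention shown in the project README, so treat the exact entry points and signatures as assumptions rather than a verified 0.3.11 API.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)

    # Polak-Ribiere direction, then a strong Wolfe line search picks the step size
    opt = tz.Modular(
        model.parameters(),
        tz.m.PolakRibiere(),
        tz.m.StrongWolfe(c2=0.1),
    )

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(25):
        loss = opt.step(closure)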
@@ -103,8 +170,12 @@ def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):


 class HestenesStiefel(ConguateGradientBase):
-    """Hestenes–Stiefel nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    """Hestenes–Stiefel nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -118,8 +189,12 @@ def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
     return (g.dot(g) / denom).neg()

 class DaiYuan(ConguateGradientBase):
-    """Dai–Yuan nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    """Dai–Yuan nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this. Although Dai–Yuan formula provides an automatic step size scaling so it is technically possible to omit line search and instead use a small step size.
+    """
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -133,8 +208,12 @@ def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
     return g.dot(g - prev_g) / denom

 class LiuStorey(ConguateGradientBase):
-    """Liu-Storey nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    """Liu-Storey nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -142,7 +221,11 @@ class LiuStorey(ConguateGradientBase):

 # ----------------------------- Conjugate Descent ---------------------------- #
 class ConjugateDescent(Transform):
-    """Conjugate Descent (CD). This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    """Conjugate Descent (CD).
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, inner: Chainable | None = None):
         super().__init__(defaults={}, uses_grad=False)

@@ -151,10 +234,10 @@ class ConjugateDescent(Transform):


     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         g = as_tensorlist(tensors)

-        prev_d = self.get_state('prev_dir', params=params, cls=TensorList, init = torch.zeros_like)
+        prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
         if 'denom' not in self.global_state:
             self.global_state['denom'] = torch.tensor(0.).to(g[0])

@@ -164,7 +247,7 @@ class ConjugateDescent(Transform):

         # inner step
         if 'inner' in self.children:
-            g = as_tensorlist(apply(self.children['inner'], g, params, grads, vars))
+            g = as_tensorlist(apply_transform(self.children['inner'], g, params, grads))

         dir = g.add_(prev_d.mul_(beta))
         prev_d.copy_(dir)
@@ -186,8 +269,11 @@ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):

 class HagerZhang(ConguateGradientBase):
     """Hager-Zhang nonlinear conjugate gradient method,
-    This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -210,9 +296,74 @@ def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):

 class HybridHS_DY(ConguateGradientBase):
     """HS-DY hybrid conjugate gradient method.
-    This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return hs_dy_beta(g, prev_d, prev_g)
+
+
+def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
+    Hy = H @ y
+    yHy = _safe_clip(y.dot(Hy))
+    H -= (Hy.outer(y) @ H) / yHy
+    return H
+
+class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
+    """Projected gradient method.
+
+    .. note::
+        This method uses N^2 memory.
+
+    .. note::
+        This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+
+    .. note::
+        This is not the same as projected gradient descent.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    """
+
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = 1,
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None | Literal['auto'] = 'auto',
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        # inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return projected_gradient_(H=H, y=y)
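
A quick way to read `projected_gradient_` above: after H ← H − (H y)(yᵀ H)/(yᵀ H y), the updated matrix maps y to zero, i.e. the curvature direction observed on the last step is projected out of the metric. A standalone check of that identity in plain torch; `_safe_clip` is replaced by a simple clamp here, which is an assumption about its role as a divide-by-zero guard.

    import torch

    def projected_gradient_update(H: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # same algebra as projected_gradient_, with a plain epsilon guard
        Hy = H @ y
        yHy = torch.clamp(y.dot(Hy), min=1e-12)
        return H - torch.outer(Hy, y) @ H / yHy

    torch.manual_seed(0)
    n = 8
    M = torch.randn(n, n, dtype=torch.float64)
    H = M @ M.T + torch.eye(n, dtype=torch.float64)   # SPD inverse-Hessian estimate
    y = torch.randn(n, dtype=torch.float64)           # gradient difference g_k - g_{k-1}

    H_new = projected_gradient_update(H, y)
    print(torch.allclose(H_new @ y, torch.zeros(n, dtype=torch.float64), atol=1e-9))   # True: y is projected out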
torchzero/modules/quasi_newton/diagonal_quasi_newton.py (new file)
@@ -0,0 +1,163 @@
+from collections.abc import Callable
+
+import torch
+
+from .quasi_newton import (
+    HessianUpdateStrategy,
+    _HessianUpdateStrategyDefaults,
+    _InverseHessianUpdateStrategyDefaults,
+    _safe_clip,
+)
+
+
+def _diag_Bv(self: HessianUpdateStrategy):
+    B, is_inverse = self.get_B()
+
+    if is_inverse:
+        H=B
+        def Hxv(v): return v/H
+        return Hxv
+
+    def Bv(v): return B*v
+    return Bv
+
+def _diag_Hv(self: HessianUpdateStrategy):
+    H, is_inverse = self.get_H()
+
+    if is_inverse:
+        B=H
+        def Bxv(v): return v/B
+        return Bxv
+
+    def Hv(v): return H*v
+    return Hv
+
+def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy < tol: return H
+
+    sy_sq = _safe_clip(sy**2)
+
+    num1 = (sy + (y * H * y)) * s*s
+    term1 = num1.div_(sy_sq)
+    num2 = (H * y * s).add_(s * y * H)
+    term2 = num2.div_(sy)
+    H += term1.sub_(term2)
+    return H
+
+class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
+    """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
+    z = s - H*y
+    denom = z.dot(y)
+
+    z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
+    y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
+
+    # if y_norm*z_norm < tol: return H
+
+    # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
+    if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
+    H += (z*z).div_(_safe_clip(denom))
+    return H
+class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
+    """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+
+
+# Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
+def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = _safe_clip((s**4).sum())
+    num = s.dot(y) - (s*B).dot(s)
+    B += s**2 * (num/denom)
+    return B
+
+class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-cauchi method.
+
+    Reference:
+        Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_qc_B_(B=B, s=s, y=y)
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+# Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    E_sq = s**2 * B**2
+    denom = _safe_clip((s*E_sq).dot(s))
+    num = s.dot(y) - (s*B).dot(s)
+    B += E_sq * (num/denom)
+    return B
+
+class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-cauchi method.
+
+    Reference:
+        Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_wqc_B_(B=B, s=s, y=y)
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+
+# Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = _safe_clip((s**4).sum())
+    num = s.dot(y) + s.dot(s) - (s*B).dot(s)
+    B += s**2 * (num/denom) - 1
+    return B
+
+class DNRTR(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-newton method.
+
+    Reference:
+        Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_wqc_B_(B=B, s=s, y=y)
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+# Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = _safe_clip((s**4).sum())
+    num = s.dot(y)
+    B += s**2 * (num/denom)
+    return B
+
+class NewDQN(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-newton method.
+
+    Reference:
+        Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return new_dqn_B_(B=B, s=s, y=y)
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
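
The diagonal updates in this new file are built around the weak secant (quasi-Cauchy) condition sᵀ B₊ s = sᵀ y rather than the full secant equation, which is why the docstrings note they don't satisfy the latter. For the plain quasi-Cauchy formula this is easy to verify numerically; a standalone check in plain torch follows, with `_safe_clip` again replaced by a clamp as an assumption about its role.

    import torch

    def diagonal_qc_update(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # same algebra as diagonal_qc_B_: B is a diagonal Hessian estimate stored as a vector
        denom = torch.clamp((s**4).sum(), min=1e-12)
        num = s.dot(y) - (s * B).dot(s)        # s^T y - s^T B s
        return B + s**2 * (num / denom)

    torch.manual_seed(0)
    n = 16
    B = torch.ones(n, dtype=torch.float64)
    s = torch.randn(n, dtype=torch.float64)    # parameter difference
    y = torch.randn(n, dtype=torch.float64)    # gradient difference

    B_new = diagonal_qc_update(B, s, y)

    # weak secant (quasi-Cauchy) condition: s^T B_new s == s^T y
    print(torch.allclose((s * B_new).dot(s), s.dot(y)))   # True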