torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/conjugate_gradient/cg.py
@@ -0,0 +1,355 @@
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import torch
+
+from ...core import (
+    Chainable,
+    Modular,
+    Module,
+    Transform,
+    Var,
+    apply_transform,
+)
+from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
+from ..line_search import LineSearchBase
+from ..quasi_newton.quasi_newton import HessianUpdateStrategy
+from ..functional import safe_clip
+
+
+class ConguateGradientBase(Transform, ABC):
+    """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
+
+    This is an abstract class, to use it, subclass it and override `get_beta`.
+
+
+    Args:
+        defaults (dict | None, optional): dictionary of settings defaults. Defaults to None.
+        clip_beta (bool, optional): whether to clip beta to be no less than 0. Defaults to False.
+        restart_interval (int | None | Literal["auto"], optional):
+            interval between resetting the search direction.
+            "auto" means number of dimensions + 1, None means no reset. Defaults to None.
+        inner (Chainable | None, optional): previous direction is added to the output of this module. Defaults to None.
+
+    Example:
+
+    ```python
+    class PolakRibiere(ConguateGradientBase):
+        def __init__(
+            self,
+            clip_beta=True,
+            restart_interval: int | None = None,
+            inner: Chainable | None = None
+        ):
+            super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+        def get_beta(self, p, g, prev_g, prev_d):
+            denom = prev_g.dot(prev_g)
+            if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
+            return g.dot(g - prev_g) / denom
+    ```
+
+    """
+    def __init__(self, defaults = None, clip_beta: bool = False, restart_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
+        if defaults is None: defaults = {}
+        defaults['restart_interval'] = restart_interval
+        defaults['clip_beta'] = clip_beta
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_grad')
+        self.global_state.pop('stage', None)
+        self.global_state['step'] = self.global_state.get('step', 1) - 1
+
+    def initialize(self, p: TensorList, g: TensorList):
+        """runs on first step when prev_grads and prev_dir are not available"""
+
+    @abstractmethod
+    def get_beta(self, p: TensorList, g: TensorList, prev_g: TensorList, prev_d: TensorList) -> float | torch.Tensor:
+        """returns beta"""
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        params = as_tensorlist(params)
+
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step
+
+        # initialize on first step
+        if self.global_state.get('stage', 0) == 0:
+            g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+            d_prev.copy_(tensors)
+            g_prev.copy_(tensors)
+            self.initialize(params, tensors)
+            self.global_state['stage'] = 1
+
+        else:
+            # if `update_tensors` was called multiple times before `apply_tensors`,
+            # stage becomes 2
+            self.global_state['stage'] = 2
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        step = self.global_state['step']
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+
+        assert self.global_state['stage'] != 0
+        if self.global_state['stage'] == 1:
+            self.global_state['stage'] = 2
+            return tensors
+
+        params = as_tensorlist(params)
+        g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+
+        # get beta
+        beta = self.get_beta(params, tensors, g_prev, d_prev)
+        if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
+
+        # inner step
+        # calculate new direction with beta
+        dir = tensors.add_(d_prev.mul_(beta))
+        d_prev.copy_(dir)
+
+        # resetting
+        restart_interval = settings[0]['restart_interval']
+        if restart_interval == 'auto': restart_interval = tensors.global_numel() + 1
+        if restart_interval is not None and step % restart_interval == 0:
+            self.state.clear()
+            self.global_state.clear()
+
+        return dir
+
+# ------------------------------- Polak-Ribière ------------------------------ #
+def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
+    denom = prev_g.dot(prev_g)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+    return g.dot(g - prev_g) / denom
+
+class PolakRibiere(ConguateGradientBase):
+    """Polak-Ribière-Polyak nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, clip_beta=True, restart_interval: int | None = None, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return polak_ribiere_beta(g, prev_g)
+
+# ------------------------------ Fletcher–Reeves ----------------------------- #
+def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
+    if prev_gg.abs() <= torch.finfo(gg.dtype).tiny * 2: return 0
+    return gg / prev_gg
+
+class FletcherReeves(ConguateGradientBase):
+    """Fletcher–Reeves nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def initialize(self, p, g):
+        self.global_state['prev_gg'] = g.dot(g)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        gg = g.dot(g)
+        beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
+        self.global_state['prev_gg'] = gg
+        return beta
+
+# ----------------------------- Hestenes–Stiefel ----------------------------- #
+def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList, prev_g: TensorList):
+    grad_diff = g - prev_g
+    denom = prev_d.dot(grad_diff)
+    if denom.abs() < torch.finfo(g[0].dtype).tiny * 2: return 0
+    return (g.dot(grad_diff) / denom).neg()
+
+
+class HestenesStiefel(ConguateGradientBase):
+    """Hestenes–Stiefel nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return hestenes_stiefel_beta(g, prev_d, prev_g)
+
+
+# --------------------------------- Dai–Yuan --------------------------------- #
+def dai_yuan_beta(g: TensorList, prev_d: TensorList, prev_g: TensorList):
+    denom = prev_d.dot(g - prev_g)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+    return (g.dot(g) / denom).neg()
+
+class DaiYuan(ConguateGradientBase):
+    """Dai–Yuan nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1)`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return dai_yuan_beta(g, prev_d, prev_g)
+
+
+# -------------------------------- Liu-Storey -------------------------------- #
+def liu_storey_beta(g: TensorList, prev_d: TensorList, prev_g: TensorList):
+    denom = prev_g.dot(prev_d)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+    return g.dot(g - prev_g) / denom
+
+class LiuStorey(ConguateGradientBase):
+    """Liu-Storey nonlinear conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return liu_storey_beta(g, prev_d, prev_g)
+
+# ----------------------------- Conjugate Descent ---------------------------- #
+def conjugate_descent_beta(g: TensorList, prev_d: TensorList, prev_g: TensorList):
+    denom = prev_g.dot(prev_d)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+    return g.dot(g) / denom
+
+class ConjugateDescent(ConguateGradientBase):
+    """Conjugate Descent (CD).
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return conjugate_descent_beta(g, prev_d, prev_g)
+
+
+# -------------------------------- Hager-Zhang ------------------------------- #
+def hager_zhang_beta(g: TensorList, prev_d: TensorList, prev_g: TensorList):
+    g_diff = g - prev_g
+    denom = prev_d.dot(g_diff)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+
+    term1 = 1/denom
+    # term2
+    term2 = (g_diff - (2 * prev_d * (g_diff.pow(2).global_sum()/denom))).dot(g)
+    return (term1 * term2).neg()
+
+
+class HagerZhang(ConguateGradientBase):
+    """Hager-Zhang nonlinear conjugate gradient method,
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return hager_zhang_beta(g, prev_d, prev_g)
+
+
+# ----------------------------------- DYHS ---------------------------------- #
+def dyhs_beta(g: TensorList, prev_d: TensorList, prev_g: TensorList):
+    grad_diff = g - prev_g
+    denom = prev_d.dot(grad_diff)
+    if denom.abs() <= torch.finfo(g[0].dtype).tiny * 2: return 0
+
+    # Dai-Yuan
+    dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
+
+    # Hestenes–Stiefel
+    hs_beta = (g.dot(grad_diff) / denom).neg().clamp(min=0)
+
+    return max(0, min(dy_beta, hs_beta)) # type:ignore
+
+class DYHS(ConguateGradientBase):
+    """Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.
+
+    Note:
+        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+    """
+    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
+        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+
+    def get_beta(self, p, g, prev_g, prev_d):
+        return dyhs_beta(g, prev_d, prev_g)
+
+
+def projected_gradient_(H: torch.Tensor, y: torch.Tensor):
+    Hy = H @ y
+    yHy = safe_clip(y.dot(Hy))
+    H -= (Hy.outer(y) @ H) / yHy
+    return H
+
+class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
+    """Projected gradient method. Directly projects the gradient onto subspace conjugate to past directions.
+
+    Notes:
+        - This method uses N^2 memory.
+        - This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
+        - This is not the same as projected gradient descent.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171. (algorithm 5 in section 6)
+
+    """
+
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = 1,
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = 'auto',
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        # inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return projected_gradient_(H=H, y=y)
+
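As the docstrings above note, these conjugate-gradient modules only produce a search direction and expect a line search module to follow them. A minimal usage sketch (illustrative only; the model, data, and closure names are assumptions, following the ``tz.Modular`` pattern used in the package's own docstring examples and the closure-with-``backward`` convention torchzero documents):

```python
import torch
import torchzero as tz

# illustrative model and data, not part of the package
model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# Polak-Ribière direction followed by a Strong Wolfe line search,
# as recommended in the module docstrings above.
opt = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),
    tz.m.StrongWolfe(c2=0.1, a_init="first-order"),
)

def closure(backward=True):
    # closure re-evaluates the loss; the line search calls it repeatedly
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
```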
torchzero/modules/experimental/__init__.py
@@ -1,24 +1,18 @@
-from .absoap import ABSOAP
-from .adadam import Adadam
-from .adamY import AdamY
-from .adasoap import AdaSOAP
+"""Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
 from .curveball import CurveBall
-from .eigendescent import EigenDescent
-from .etf import (
-    ExponentialTrajectoryFit,
-    ExponentialTrajectoryFitV2,
-    PointwiseExponential,
-)
+
+# from dct import DCTProjection
+from .fft import FFTProjection
 from .gradmin import GradMin
+from .l_infinity import InfinityNormTrustRegion
+from .momentum import (
+    CoordinateMomentum,
+    NesterovEMASquared,
+    PrecenteredEMASquared,
+    SqrtNesterovEMASquared,
+)
 from .newton_solver import NewtonSolver
 from .newtonnewton import NewtonNewton
 from .reduce_outward_lr import ReduceOutwardLR
-from .soapy import SOAPY
-from .spectral import SpectralPreconditioner
-from .structured_newton import StructuredNewton
-from .subspace_preconditioners import (
-    HistorySubspacePreconditioning,
-    RandomSubspacePreconditioning,
-)
-from .tada import TAda
-from .diagonal_higher_order_newton import DiagonalHigherOrderNewton
+from .scipy_newton_cg import ScipyNewtonCG
+from .structural_projections import BlockPartition, TensorizeProjection
torchzero/modules/{projections → experimental}/dct.py
@@ -1,13 +1,13 @@
 from typing import Literal
 import torch
 import torch_dct
-from .projection import Projection
+from ..projections import ProjectionBase
 from ...core import Chainable
 
 def reverse_dims(t:torch.Tensor):
     return t.permute(*reversed(range(t.ndim)))
 
-class DCTProjection(Projection):
+class DCTProjection(ProjectionBase):
     # norm description copied from pytorch docstring
     """Project update into Discrete Cosine Transform space, requires `torch_dct` library.
 
@@ -34,8 +34,8 @@ class DCTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors, var, current):
-        settings = self.settings[var.params[0]]
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         dims = settings['dims']
         norm = settings['norm']
 
@@ -54,18 +54,18 @@ class DCTProjection(Projection):
         return projected
 
     @torch.no_grad
-    def unproject(self, tensors, var, current):
-        settings = self.settings[var.params[0]]
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         dims = settings['dims']
         norm = settings['norm']
 
         unprojected = []
-        for u in tensors:
-            dim = min(u.ndim, dims)
+        for t in projected_tensors:
+            dim = min(t.ndim, dims)
 
-            if dim == 1: idct = torch_dct.idct(u, norm = norm)
-            elif dim == 2: idct = torch_dct.idct_2d(u, norm=norm)
-            elif dim == 3: idct = torch_dct.idct_3d(u, norm=norm)
+            if dim == 1: idct = torch_dct.idct(t, norm = norm)
+            elif dim == 2: idct = torch_dct.idct_2d(t, norm=norm)
+            elif dim == 3: idct = torch_dct.idct_3d(t, norm=norm)
             else: raise ValueError(f"Unsupported number of dimensions {dim}")
 
             unprojected.append(reverse_dims(idct))
torchzero/modules/{projections → experimental}/fft.py
@@ -2,12 +2,12 @@ import torch
 
 from ...core import Chainable
 from ...utils import vec_to_tensors
-from .projection import Projection
+from ..projections import ProjectionBase
 
 
-class FFTProjection(Projection):
+class FFTProjection(ProjectionBase):
     # norm description copied from pytorch docstring
-    """Project update into Fourrier space of real-valued inputs.
+    """Project update into Fourier space of real-valued inputs.
 
     Args:
         modules (Chainable): modules that will optimize the projected update.
@@ -45,8 +45,8 @@ class FFTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors, var, current):
-        settings = self.settings[var.params[0]]
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
        one_d = settings['one_d']
         norm = settings['norm']
 
@@ -60,14 +60,14 @@ class FFTProjection(Projection):
         return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable
 
     @torch.no_grad
-    def unproject(self, tensors, var, current):
-        settings = self.settings[var.params[0]]
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         one_d = settings['one_d']
         norm = settings['norm']
 
         if one_d:
-            vec = torch.view_as_complex(tensors[0])
+            vec = torch.view_as_complex(projected_tensors[0])
             unprojected_vec = torch.fft.irfft(vec, n=self.global_state['length'], norm=norm) # pylint:disable=not-callable
-            return vec_to_tensors(unprojected_vec, reference=var.params)
+            return vec_to_tensors(unprojected_vec, reference=params)
 
-        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(tensors, var.params)] # pylint:disable=not-callable
+        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(projected_tensors, params)] # pylint:disable=not-callable
torchzero/modules/experimental/gradmin.py
@@ -5,11 +5,11 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Var
+from ...core import Module, Var, Chainable
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget
-from ..smoothing.gaussian import Reformulation
+from ..smoothing.sampling import Reformulation
 
 
 
@@ -28,6 +28,7 @@ class GradMin(Reformulation):
     """
     def __init__(
         self,
+        modules: Chainable,
         loss_term: float | None = 0,
         relative: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
         graft: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
@@ -39,7 +40,7 @@
     ):
         if (relative is not None) and (graft is not None): warnings.warn('both relative and graft loss are True, they will clash with each other')
         defaults = dict(loss_term=loss_term, relative=relative, graft=graft, square=square, mean=mean, maximize_grad=maximize_grad, create_graph=create_graph, modify_loss=modify_loss)
-        super().__init__(defaults)
+        super().__init__(defaults, modules=modules)
 
     @torch.no_grad
     def closure(self, backward, closure, params, var):
torchzero/modules/experimental/l_infinity.py
@@ -0,0 +1,111 @@
+
+import numpy as np
+import torch
+from scipy.optimize import lsq_linear
+
+from ...core import Chainable, Module
+from ..trust_region.trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
+
+
+class InfinityNormTrustRegion(TrustRegionBase):
+    """Trust region with L-infinity norm via ``scipy.optimize.lsq_linear``.
+
+    Args:
+        hess_module (Module | None, optional):
+            A module that maintains a hessian approximation (not hessian inverse!).
+            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
+            When using quasi-newton methods, set `inverse=False` when constructing them.
+        eta (float, optional):
+            if ratio of actual to predicted rediction is larger than this, step is accepted.
+            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
+        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
+        rho_good (float, optional):
+            if ratio of actual to predicted rediction is larger than this, trust region size is multiplied by `nplus`.
+        rho_bad (float, optional):
+            if ratio of actual to predicted rediction is less than this, trust region size is multiplied by `nminus`.
+        init (float, optional): Initial trust region value. Defaults to 1.
+        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+        max_attempts (max_attempts, optional):
+            maximum number of trust region size size reductions per step. A zero update vector is returned when
+            this limit is exceeded. Defaults to 10.
+        boundary_tol (float | None, optional):
+            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
+            This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-2.
+        tol (float | None, optional): tolerance for least squares solver.
+        fallback (bool, optional):
+            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
+            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
+        inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
+
+    Examples:
+        BFGS with infinity-norm trust region
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)),
+            )
+    """
+    def __init__(
+        self,
+        hess_module: Module,
+        prefer_dense: bool = True,
+        tol: float = 1e-10,
+        eta: float = 0.0,
+        nplus: float = 3.5,
+        nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
+        boundary_tol: float | None = None,
+        init: float = 1,
+        max_attempts: int = 10,
+        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(tol=tol, prefer_dense=prefer_dense)
+        super().__init__(
+            defaults=defaults,
+            hess_module=hess_module,
+            eta=eta,
+            nplus=nplus,
+            nminus=nminus,
+            rho_good=rho_good,
+            rho_bad=rho_bad,
+            boundary_tol=boundary_tol,
+            init=init,
+            max_attempts=max_attempts,
+            radius_strategy=radius_strategy,
+            update_freq=update_freq,
+            inner=inner,
+
+            radius_fn=torch.amax,
+        )
+
+    def trust_solve(self, f, g, H, radius, params, closure, settings):
+        if settings['prefer_dense'] and H.is_dense():
+            # convert to array if possible to avoid many conversions
+            # between torch and numpy, plus it seems that it uses
+            # a better solver
+            A = H.to_tensor().numpy(force=True).astype(np.float64)
+        else:
+            # memory efficient linear operator (is this still faster on CUDA?)
+            A = H.scipy_linop()
+
+        try:
+            d_np = lsq_linear(
+                A,
+                g.numpy(force=True).astype(np.float64),
+                tol=settings['bounds'],
+                bounds=(-radius, radius),
+            ).x
+            return torch.as_tensor(d_np, device=g.device, dtype=g.dtype)
+
+        except np.linalg.LinAlgError:
+            self.children['hess_module'].reset()
+            g_max = g.amax()
+            if g_max > radius:
+                g = g * (radius / g_max)
+            return g
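The ``trust_solve`` above hands the L-infinity trust-region subproblem to SciPy as a box-constrained linear least-squares problem: minimize ``||H d - g||`` subject to ``-radius <= d_i <= radius``. A standalone sketch of that call, with a small made-up ``H`` and ``g`` (not taken from the package):

```python
import numpy as np
from scipy.optimize import lsq_linear

# Illustrative quadratic model: H is a dense Hessian approximation,
# g the gradient, and the L-infinity trust region is a box of half-width `radius`.
H = np.array([[4.0, 1.0], [1.0, 3.0]])
g = np.array([1.0, -2.0])
radius = 0.5

# Minimize ||H d - g|| subject to -radius <= d_i <= radius,
# mirroring the lsq_linear call in InfinityNormTrustRegion.trust_solve.
res = lsq_linear(H, g, bounds=(-radius, radius))
d = res.x  # candidate step, confined to the trust-region box
print(d)
```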