torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/higher_order_newton.py

@@ -1,21 +1,12 @@
-import itertools
 import math
-import warnings
-from collections.abc import Callable
-from contextlib import nullcontext
-from functools import partial
 from typing import Any, Literal
 
 import numpy as np
 import scipy.optimize
 import torch
 
-from ...core import Chainable, Module, apply_transform
+from ...core import DerivativesMethod, Module
 from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    flatten_jacobian,
-    jacobian_wrt,
-)
 
 _LETTERS = 'abcdefghijklmnopqrstuvwxyz'
 def _poly_eval(s: np.ndarray, c, derivatives):
@@ -195,22 +186,22 @@ class HigherOrderNewton(Module):
         max_attempts = 10,
         boundary_tol: float = 1e-2,
         de_iters: int | None = None,
-        vectorize: bool = True,
+        derivatives_method: DerivativesMethod = "batched_autograd",
     ):
         if init is None:
             if trust_method == 'bounds': init = 1
             else: init = 0.1
 
-        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad, derivatives_method=derivatives_method)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
+    def apply(self, objective):
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
 
-        settings = self.settings[params[0]]
+        settings = self.defaults
         order = settings['order']
         nplus = settings['nplus']
         nminus = settings['nminus']
@@ -219,31 +210,12 @@ class HigherOrderNewton(Module):
         trust_method = settings['trust_method']
         de_iters = settings['de_iters']
         max_attempts = settings['max_attempts']
-        vectorize = settings['vectorize']
         boundary_tol = settings['boundary_tol']
         rho_good = settings['rho_good']
         rho_bad = settings['rho_bad']
 
         # ------------------------ calculate grad and hessian ------------------------ #
-        with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
-
-            g_list = torch.autograd.grad(loss, params, create_graph=True)
-            var.grad = list(g_list)
-
-            g = torch.cat([t.ravel() for t in g_list])
-            n = g.numel()
-            derivatives = [g]
-            T = g # current derivatives tensor
-
-            # get all derivative up to order
-            for o in range(2, order + 1):
-                is_last = o == order
-                T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-                with torch.no_grad() if is_last else nullcontext():
-                    # the shape is (ndim, ) * order
-                    T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                    derivatives.append(T)
+        loss, *derivatives = objective.derivatives(order=order, at_x0=True, method=self.defaults["derivatives_method"])
 
         x0 = torch.cat([p.ravel() for p in params])
 
@@ -301,7 +273,8 @@ class HigherOrderNewton(Module):
             vec_to_tensors_(x0, params)
             reduction = loss - loss_star
 
-            rho = reduction / (max(pred_reduction, 1e-8))
+            rho = reduction / (max(pred_reduction, finfo.tiny * 2)) # pyright:ignore[reportArgumentType]
+
             # failed step
             if rho < rho_bad:
                 self.global_state['trust_region'] = trust_value * nminus
@@ -320,8 +293,9 @@ class HigherOrderNewton(Module):
         assert x_star is not None
         if success:
             difference = vec_to_tensors(x0 - x_star, params)
-            var.update = list(difference)
+            objective.updates = list(difference)
         else:
-            var.update = params.zeros_like()
-        return var
+            objective.updates = params.zeros_like()
+
+        return objective
 

torchzero/modules/experimental/l_infinity.py

@@ -43,7 +43,7 @@ class InfinityNormTrustRegion(TrustRegionBase):
 
    .. code-block:: python
 
-        opt = tz.Modular(
+        opt = tz.Optimizer(
            model.parameters(),
            tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)),
        )
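
The docstring change above reflects a release-wide rename: the 0.3.x `tz.Modular` entry point becomes `tz.Optimizer` in 0.4.x (the same swap appears in the `NewtonSolver` default solver further down). A minimal migration sketch; the closure convention with a `backward` flag follows the `closure(False)` calls visible elsewhere in this diff, `compute_loss` is a hypothetical user function, and whether `tz.Modular` survives as an alias is not confirmed here:

    import torchzero as tz

    # 0.3.15
    # opt = tz.Modular(model.parameters(), tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)))

    # 0.4.1
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)),
    )

    def closure(backward=True):
        loss = compute_loss()  # hypothetical: evaluate the objective
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)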

torchzero/modules/experimental/matrix_nag.py (new file)

@@ -0,0 +1,122 @@
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+from torchzero.core import Chainable, Transform, HVPMethod
+from torchzero.utils import NumberList, TensorList
+
+
+def matrix_nag_(
+    tensors_: TensorList,
+    s: TensorList,
+    Hvp_fn: Callable,
+    mu: float | NumberList,
+):
+    s += tensors_
+    Hv = TensorList(Hvp_fn(s))
+    s -= Hv.mul_(mu)
+    return tensors_.add_(s)
+
+
+class MatrixNAG(Transform):
+    """nesterov momentum version of matrix momentum. It seemed to work really well but adapting doesn't work,
+    I need to test more"""
+    def __init__(
+        self,
+        mu=0.1,
+        hvp_method: HVPMethod = "autograd",
+        h: float = 1e-3,
+        adaptive:bool = False,
+        adapt_freq: int | None = None,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
+        super().__init__(defaults)
+
+        if hvp_tfm is not None:
+            self.set_child('hvp_tfm', hvp_tfm)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('p_prev')
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        assert objective.closure is not None
+        step = self.global_state.get("step", 0)
+        self.global_state["step"] = step + 1
+
+        p = TensorList(objective.params)
+        g = TensorList(objective.get_grads(create_graph=self.defaults["hvp_method"] == "autograd"))
+        p_prev = self.get_state(p, "p_prev", init=p, cls=TensorList)
+        s = p - p_prev
+        p_prev.copy_(p)
+
+        # -------------------------------- adaptive mu ------------------------------- #
+        if self.defaults["adaptive"]:
+
+            if step == 1:
+                self.global_state["mu_mul"] = 0
+
+            else:
+                # ---------------------------- deterministic case ---------------------------- #
+                if self.defaults["adapt_freq"] is None:
+                    g_prev = self.get_state(objective.params, "g_prev", cls=TensorList)
+                    y = g - g_prev
+                    g_prev.copy_(g)
+
+                    denom = y.global_vector_norm()
+                    denom = denom.clip(min = torch.finfo(denom.dtype).tiny * 2)
+                    self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+                # -------------------------------- stochastic -------------------------------- #
+                else:
+                    adapt_freq = self.defaults["adapt_freq"]
+
+                    # we start on 1nd step, and want to adapt when we start, so use (step - 1)
+                    if (step - 1) % adapt_freq == 0:
+                        assert objective.closure is not None
+                        p_cur = p.clone()
+
+                        # move to previous params and evaluate p_prev with current mini-batch
+                        p.copy_(self.get_state(objective.params, 'p_prev'))
+                        with torch.enable_grad():
+                            objective.closure()
+                        g_prev = [t.grad if t.grad is not None else torch.zeros_like(t) for t in p]
+                        y = g - g_prev
+
+                        # move back to current params
+                        p.copy_(p_cur)
+
+                        denom = y.global_vector_norm()
+                        denom = denom.clip(min = torch.finfo(denom.dtype).tiny * 2)
+                        self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+        # -------------------------- matrix momentum update -------------------------- #
+        mu = self.get_settings(p, "mu", cls=NumberList)
+        if "mu_mul" in self.global_state:
+            mu = mu * self.global_state["mu_mul"]
+
+        # def Hvp_fn(v):
+        #     Hv, _ = self.Hvp(
+        #         v=v,
+        #         at_x0=True,
+        #         var=objective,
+        #         rgrad=g,
+        #         hvp_method=self.defaults["hvp_method"],
+        #         h=self.defaults["h"],
+        #         normalize=True,
+        #         retain_grad=False,
+        #     )
+        #     return Hv
+
+        _, Hvp_fn = objective.list_Hvp_function(hvp_method=self.defaults["hvp_method"], h=self.defaults["h"], at_x0=True)
+
+        objective.updates = matrix_nag_(
+            tensors_=TensorList(objective.get_updates()),
+            s=s,
+            Hvp_fn=Hvp_fn,
+            mu=mu,
+        )
+
+        return objective
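
MatrixNAG asserts that a closure is present (it may re-evaluate the objective for Hessian-vector products and for the stochastic mu adaptation above). A usage sketch under the new 0.4.x API; the `tz.m.MatrixNAG` registration path and the `tz.m.LR` step-size module are assumptions, only the class definition itself appears in this diff:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(16, 1)
    X, y = torch.randn(128, 16), torch.randn(128, 1)

    opt = tz.Optimizer(model.parameters(), tz.m.MatrixNAG(mu=0.1, hvp_method="autograd"), tz.m.LR(1e-2))

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(50):
        opt.step(closure)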

torchzero/modules/experimental/newton_solver.py

@@ -1,11 +1,10 @@
-from collections.abc import Callable, Iterable
-from typing import Any, Literal, overload
+from collections.abc import Callable
+from typing import Any
 
 import torch
 
-from ...core import Chainable, Modular, Module, apply_transform
-from ...utils import TensorList, as_tensorlist
-from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+from ...core import Chainable, Optimizer, Module, step, HVPMethod
+from ...utils import TensorList
 from ..quasi_newton import LBFGS
 
 
@@ -13,30 +12,32 @@ class NewtonSolver(Module):
     """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)."""
     def __init__(
         self,
-        solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
+        solver: Callable[[list[torch.Tensor]], Any] = lambda p: Optimizer(p, LBFGS()),
         maxiter=None,
         maxiter1=None,
         tol:float | None=1e-3,
         reg: float = 0,
         warm_start=True,
-        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        hvp_method: HVPMethod = "autograd",
         reset_solver: bool = False,
         h: float= 1e-3,
+
         inner: Chainable | None = None,
     ):
-        defaults = dict(tol=tol, h=h,reset_solver=reset_solver, maxiter=maxiter, maxiter1=maxiter1, reg=reg, warm_start=warm_start, solver=solver, hvp_method=hvp_method)
-        super().__init__(defaults,)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults)
 
-        if inner is not None:
-            self.set_child('inner', inner)
+        self.set_child("inner", inner)
 
         self._num_hvps = 0
         self._num_hvps_last_step = 0
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
+    def apply(self, objective):
+
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
@@ -44,51 +45,19 @@ class NewtonSolver(Module):
         maxiter = settings['maxiter']
         maxiter1 = settings['maxiter1']
         tol = settings['tol']
-        reg = settings['reg']
         hvp_method = settings['hvp_method']
         warm_start = settings['warm_start']
         h = settings['h']
         reset_solver = settings['reset_solver']
 
         self._num_hvps_last_step = 0
-        # ---------------------- Hessian vector product function --------------------- #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-
-            def H_mm(x):
-                self._num_hvps_last_step += 1
-                with torch.enable_grad():
-                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
-                if reg != 0: Hvp = Hvp + (x*reg)
-                return Hvp
-
-        else:
-
-            with torch.enable_grad():
-                grad = var.get_grad()
-
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
-                    if reg != 0: Hvp = Hvp + (x*reg)
-                    return Hvp
-
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
-                    if reg != 0: Hvp = Hvp + (x*reg)
-                    return Hvp
-
-            else:
-                raise ValueError(hvp_method)
 
+        # ---------------------- Hessian vector product function --------------------- #
+        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
 
         # -------------------------------- inner step -------------------------------- #
-        b = as_tensorlist(grad)
-        if 'inner' in self.children:
-            b = as_tensorlist(apply_transform(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, var=var))
+        objective = self.inner_step("inner", objective, must_exist=False)
+        b = TensorList(objective.get_updates())
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
@@ -112,7 +81,7 @@
         solver = self.global_state['solver']
 
         def lstsq_closure(backward=True):
-            Hx = H_mm(x).detach()
+            Hx = H_mv(x).detach()
            # loss = (Hx-b).pow(2).global_mean()
            # if backward:
            #     solver.zero_grad()
@@ -122,7 +91,7 @@
            loss = residual.pow(2).global_mean()
            if backward:
                with torch.no_grad():
-                    H_residual = H_mm(residual)
+                    H_residual = H_mv(residual)
                    n = residual.global_numel()
                    x.set_grad_((2.0 / n) * H_residual)
 
@@ -143,8 +112,8 @@
         assert x0 is not None
         x0.copy_(x)
 
-        var.update = x.detach()
+        objective.updates = x.detach()
         self._num_hvps += self._num_hvps_last_step
-        return var
+        return objective
 
 

torchzero/modules/experimental/newtonnewton.py

@@ -7,21 +7,21 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, vec_to_tensors
+from ...core import Chainable, Transform, step
+from ...linalg.linear_operator import Dense
+from ...utils import TensorList, vec_to_tensors_
 from ...utils.derivatives import (
     flatten_jacobian,
     jacobian_wrt,
 )
 from ..second_order.newton import (
-    _cholesky_solve,
-    _eigh_solve,
+    _try_cholesky_solve,
     _least_squares_solve,
-    _lu_solve,
+    _try_lu_solve,
 )
-from ...utils.linalg.linear_operator import Dense
 
-class NewtonNewton(Module):
+
+class NewtonNewton(Transform):
     """Applies Newton-like preconditioning to Newton step.
 
     This is a method that I thought of and then it worked. Here is how it works:
@@ -33,42 +33,36 @@ class NewtonNewton(Module):
     3. Solve H2 x2 = x for x2.
 
     4. Optionally, repeat (if order is higher than 3.)
-
-    Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
-
-    3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
-    this is if pytorch can vectorize hessian computation efficiently.
     """
     def __init__(
         self,
         reg: float = 1e-6,
         order: int = 3,
-        search_negative: bool = False,
         vectorize: bool = True,
-        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        inner: Chainable | None = None,
     ):
-        defaults = dict(order=order, reg=reg, vectorize=vectorize, eigval_fn=eigval_fn, search_negative=search_negative)
-        super().__init__(defaults)
+        defaults = dict(order=order, reg=reg, vectorize=vectorize)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def update(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonNewton requires closure')
 
-        settings = self.settings[params[0]]
-        reg = settings['reg']
-        vectorize = settings['vectorize']
-        order = settings['order']
-        search_negative = settings['search_negative']
-        eigval_fn = settings['eigval_fn']
+        reg = fs['reg']
+        vectorize = fs['vectorize']
+        order = fs['order']
 
         # ------------------------ calculate grad and hessian ------------------------ #
-        Hs = []
+        P = None
         with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
+            loss = objective.loss = objective.loss_approx = closure(False)
             g_list = torch.autograd.grad(loss, params, create_graph=True)
-            var.grad = list(g_list)
+            objective.grads = list(g_list)
 
             xp = torch.cat([t.ravel() for t in g_list])
             I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
@@ -79,27 +73,30 @@
                 with torch.no_grad() if is_last else nullcontext():
                     H = flatten_jacobian(H_list)
                     if reg != 0: H = H + I * reg
-                    Hs.append(H)
+                    if P is None: P = H
+                    else: P = P @ H
+
+                    if not is_last:
+                        x = _try_cholesky_solve(H, xp)
+                        if x is None: x = _try_lu_solve(H, xp)
+                        if x is None: x = _least_squares_solve(H, xp)
+                        xp = x.squeeze()
+
+        self.global_state["P"] = P
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
+        P = self.global_state['P']
+        b = torch.cat([t.ravel() for t in updates])
 
-                    x = None
-                    if search_negative or (is_last and eigval_fn is not None):
-                        x = _eigh_solve(H, xp, eigval_fn, search_negative=search_negative)
-                    if x is None: x = _cholesky_solve(H, xp)
-                    if x is None: x = _lu_solve(H, xp)
-                    if x is None: x = _least_squares_solve(H, xp)
-                    xp = x.squeeze()
+        sol = _try_cholesky_solve(P, b)
+        if sol is None: sol = _try_lu_solve(P, b)
+        if sol is None: sol = _least_squares_solve(P, b)
 
-        self.global_state["Hs"] = Hs
-        self.global_state['xp'] = xp.nan_to_num_(0,0,0)
+        vec_to_tensors_(sol, updates)
+        return objective
 
     @torch.no_grad
-    def apply(self, var):
-        params = var.params
-        xp = self.global_state['xp']
-        var.update = vec_to_tensors(xp, params)
-        return var
-
-    def get_H(self, var):
-        Hs = self.global_state["Hs"]
-        if len(Hs) == 1: return Dense(Hs[0])
-        return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
+    def get_H(self, objective=...):
+        return Dense(self.global_state["P"])
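
`_try_cholesky_solve`, `_try_lu_solve` and `_least_squares_solve` form a fallback chain: Cholesky when the matrix is positive definite, LU when it is merely non-singular, least squares otherwise. A generic sketch of the same pattern in plain `torch.linalg` (illustrative only, not the helpers imported from `..second_order.newton`):

    import torch

    def robust_solve(H: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        try:  # Cholesky: fastest, requires a positive definite H
            L = torch.linalg.cholesky(H)
            return torch.cholesky_solve(b.unsqueeze(-1), L).squeeze(-1)
        except torch.linalg.LinAlgError:
            pass
        try:  # LU: works for any non-singular H
            return torch.linalg.solve(H, b)
        except torch.linalg.LinAlgError:
            pass
        # least squares: always returns a (pseudo-)solution
        return torch.linalg.lstsq(H, b.unsqueeze(-1)).solution.squeeze(-1)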

torchzero/modules/experimental/reduce_outward_lr.py

@@ -1,28 +1,28 @@
 import torch
 
-from ...core import Target, Transform
+from ...core import TensorTransform
 from ...utils import TensorList, unpack_states, unpack_dicts
 
-class ReduceOutwardLR(Transform):
+class ReduceOutwardLR(TensorTransform):
     """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
 
-    .. warning::
+    Warning:
         This sounded good but after testing turns out it sucks.
     """
-    def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
+    def __init__(self, mul = 0.5, use_grad=False, invert=False):
         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
-        super().__init__(defaults, uses_grad=use_grad, target=target)
+        super().__init__(defaults, uses_grad=use_grad)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         params = TensorList(params)
         tensors = TensorList(tensors)
 
         mul = [s['mul'] for s in settings]
         s = settings[0]
-        use_grad = s['use_grad']
+        use_grad = self._uses_grad
         invert = s['invert']
 
         if use_grad: cur = grads

torchzero/modules/experimental/scipy_newton_cg.py

@@ -3,10 +3,9 @@ from typing import Literal, overload
 import torch
 from scipy.sparse.linalg import LinearOperator, gcrotmk
 
-from ...core import Chainable, Module, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist, generic_vector_norm, vec_to_tensors
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-from ...utils.linalg.solve import cg, minres
+from ...core import Chainable, Module, step
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import hvp_fd_central, hvp_fd_forward
 
 
 class ScipyNewtonCG(Module):
@@ -14,7 +13,7 @@ class ScipyNewtonCG(Module):
     def __init__(
         self,
         solver = gcrotmk,
-        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        hvp_method: Literal["fd_forward", "fd_central", "autograd"] = "autograd",
         h: float = 1e-3,
         warm_start=False,
         inner: Chainable | None = None,
@@ -33,47 +32,47 @@ class ScipyNewtonCG(Module):
         self._kwargs = kwargs
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
+    def apply(self, objective):
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        solver = settings['solver']
-        h = settings['h']
-        warm_start = settings['warm_start']
+        fs = self.settings[params[0]]
+        hvp_method = fs['hvp_method']
+        solver = fs['solver']
+        h = fs['h']
+        warm_start = fs['warm_start']
 
         self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
         device = params[0].device; dtype=params[0].dtype
         if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
+            grad = objective.get_grads(create_graph=True)
 
             def H_mm(x_np):
                 self._num_hvps_last_step += 1
                 x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
                 with torch.enable_grad():
-                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
+                    Hvp = TensorList(torch.autograd.grad(grad, params, x, retain_graph=True))
                 return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
 
         else:
 
             with torch.enable_grad():
-                grad = var.get_grad()
+                grad = objective.get_grads()
 
             if hvp_method == 'forward':
                 def H_mm(x_np):
                     self._num_hvps_last_step += 1
                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
-                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad)[1])
                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
 
             elif hvp_method == 'central':
                 def H_mm(x_np):
                     self._num_hvps_last_step += 1
                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
-                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h)[1])
                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
 
             else:
@@ -83,10 +82,8 @@ class ScipyNewtonCG(Module):
         H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore
 
         # -------------------------------- inner step -------------------------------- #
-        b = var.get_update()
-        if 'inner' in self.children:
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-        b = as_tensorlist(b)
+        objective = self.inner_step("inner", objective, must_exist=False)
+        b = TensorList(objective.get_updates())
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
@@ -98,8 +95,8 @@
         if warm_start:
             self.global_state['x_prev'] = x_np
 
-        var.update = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)
+        objective.updates = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)
 
         self._num_hvps += self._num_hvps_last_step
-        return var
+        return objective
 