torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,85 @@
+ from collections import deque
+
+ import torch
+
+ from ...core import TensorwiseTransform
+
+
+ def eigh_solve(H: torch.Tensor, g: torch.Tensor):
+     try:
+         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+         return Q @ ((Q.mH @ g) / L)
+     except torch.linalg.LinAlgError:
+         return None
+
+
+ class HNewton(TensorwiseTransform):
+     """Treats gradient differences as Hessian-vector products (Hvps) whose vectors are the corresponding parameter differences, using pairs of past iterates that are close to each other. Essentially another limited-memory quasi-Newton method, included for testing.
+
+     .. warning::
+         Experimental.
+
+     """
+     def __init__(self, history_size: int, window_size: int, reg: float=0, tol: float = 1e-8, concat_params:bool=True, inner=None):
+         defaults = dict(history_size=history_size, window_size=window_size, reg=reg, tol=tol)
+         super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner)
+
+     def update_tensor(self, tensor, param, grad, loss, state, setting):
+
+         history_size = setting['history_size']
+
+         if 'param_history' not in state:
+             state['param_history'] = deque(maxlen=history_size)
+             state['grad_history'] = deque(maxlen=history_size)
+
+         param_history: deque = state['param_history']
+         grad_history: deque = state['grad_history']
+         param_history.append(param.ravel())
+         grad_history.append(tensor.ravel())
+
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         window_size = setting['window_size']
+         reg = setting['reg']
+         tol = setting['tol']
+
+         param_history: deque = state['param_history']
+         grad_history: deque = state['grad_history']
+         g = tensor.ravel()
+
+         n = len(param_history)
+         s_list = []
+         y_list = []
+
+         for i in range(n):
+             for j in range(i):
+                 if i - j <= window_size:
+                     p_i, g_i = param_history[i], grad_history[i]
+                     p_j, g_j = param_history[j], grad_history[j]
+                     s = p_i - p_j # vec in hvp
+                     y = g_i - g_j # hvp
+                     if s.dot(y) > tol:
+                         s_list.append(s)
+                         y_list.append(y)
+
+         if len(s_list) < 1:
+             scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
+             tensor.mul_(scale)
+             return tensor
+
+         S = torch.stack(s_list, 1)
+         Y = torch.stack(y_list, 1)
+
+         B = S.T @ Y
+         if reg != 0: B.add_(torch.eye(B.size(0), device=B.device, dtype=B.dtype).mul_(reg))
+         g_proj = g @ S
+
+         newton_proj, info = torch.linalg.solve_ex(B, g_proj) # pylint:disable=not-callable
+         if info != 0:
+             newton_proj = -torch.linalg.lstsq(B, g_proj).solution # pylint:disable=not-callable
+         newton = S @ newton_proj
+         return newton.view_as(tensor)
+
+
+         # scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
+         # tensor.mul_(scale)
+         # return tensor
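
For reference, here is a minimal standalone sketch of the projected solve that HNewton performs, written on flat tensors rather than torchzero's per-tensor state; the function name and defaults are illustrative, not part of the package:

    import torch

    def hnewton_direction(g, param_history, grad_history, window_size=3, tol=1e-8, reg=0.0):
        # collect (s, y) pairs from nearby past iterates: y = g_i - g_j is treated
        # as the Hvp of s = p_i - p_j
        s_list, y_list = [], []
        n = len(param_history)
        for i in range(n):
            for j in range(i):
                if i - j <= window_size:
                    s = param_history[i] - param_history[j]
                    y = grad_history[i] - grad_history[j]
                    if s.dot(y) > tol:          # keep only curvature-positive pairs
                        s_list.append(s); y_list.append(y)
        if not s_list:
            return g                            # fall back to the (scaled) gradient
        S = torch.stack(s_list, 1)              # (d, m)
        Y = torch.stack(y_list, 1)
        B = S.T @ Y                             # m x m projected "Hessian": S^T H S ~= S^T Y
        if reg: B = B + reg * torch.eye(B.size(0))
        coeffs = torch.linalg.solve(B, S.T @ g) # solve in the span of S
        return S @ coeffs                       # lift back to parameter space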
@@ -4,8 +4,8 @@ from typing import Any
 
  import torch
 
- from ....core import Chainable, Module, Transform, Vars, apply, maybe_chain
- from ....utils import NumberList, TensorList, as_tensorlist
+ from ...core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
+ from ...utils import NumberList, TensorList, as_tensorlist
 
 
  def _adaptive_damping(
@@ -28,7 +28,7 @@ def _adaptive_damping(
 
  def lbfgs(
      tensors_: TensorList,
-     vars: Vars,
+     var: Var,
      s_history: deque[TensorList],
      y_history: deque[TensorList],
      sy_history: deque[torch.Tensor],
@@ -43,58 +43,57 @@ def lbfgs(
          if scale < 1e-5: scale = 1 / tensors_.abs().mean()
          return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
 
-     else:
-         # 1st loop
-         alpha_list = []
-         q = tensors_.clone()
-         for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-             p_i = 1 / ys_i # this is also denoted as ρ (rho)
-             alpha = p_i * s_i.dot(q)
-             alpha_list.append(alpha)
-             q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-         # calculate z
-         # s.y/y.y is the initial scaling factor, usually denoted γ (gamma)
-         # z is it times q
-         # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-         z = q * (ys_k / (y_k.dot(y_k)))
-
-         if z_tfm is not None:
-             z = TensorList(apply(z_tfm, tensors=z, params=vars.params, grads=vars.grad, vars=vars))
-
-         # 2nd loop
-         for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
-             p_i = 1 / ys_i
-             beta_i = p_i * y_i.dot(z)
-             z.add_(s_i, alpha = alpha_i - beta_i)
-
-         return z
+     # 1st loop
+     alpha_list = []
+     q = tensors_.clone()
+     for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+         p_i = 1 / ys_i # this is also denoted as ρ (rho)
+         alpha = p_i * s_i.dot(q)
+         alpha_list.append(alpha)
+         q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
+
+     # calculate z
+     # s.y/y.y is the initial scaling factor, usually denoted γ (gamma)
+     # z is it times q
+     # actually H0 = (s.y/y.y) * I, and z = H0 @ q
+     z = q * (ys_k / (y_k.dot(y_k)))
+
+     if z_tfm is not None:
+         z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
+
+     # 2nd loop
+     for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+         p_i = 1 / ys_i
+         beta_i = p_i * y_i.dot(z)
+         z.add_(s_i, alpha = alpha_i - beta_i)
+
+     return z
 
  def _apply_tfms_into_history(
      self: Module,
      params: list[torch.Tensor],
-     vars: Vars,
+     var: Var,
      update: list[torch.Tensor],
  ):
      if 'params_history_tfm' in self.children:
-         params = apply(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)
+         params = apply_transform(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
 
      if 'grad_history_tfm' in self.children:
-         update = apply(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=vars.grad, vars=vars)
+         update = apply_transform(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=var.grad, var=var)
 
      return params, update
 
  def _apply_tfms_into_precond(
      self: Module,
      params: list[torch.Tensor],
-     vars: Vars,
+     var: Var,
      update: list[torch.Tensor],
  ):
      if 'params_precond_tfm' in self.children:
-         params = apply(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)
+         params = apply_transform(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
 
      if 'grad_precond_tfm' in self.children:
-         update = apply(self.children['grad_precond_tfm'], tensors=update, params=params, grads=vars.grad, vars=vars)
+         update = apply_transform(self.children['grad_precond_tfm'], tensors=update, params=params, grads=var.grad, var=var)
 
      return params, update
 
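
The comments above describe the standard L-BFGS two-loop recursion; for reference, a minimal flat-tensor sketch of the same computation (the real code operates on TensorLists, caches s·y in sy_history, and supports the z_tfm hook):

    import torch

    def lbfgs_direction(g, s_history, y_history):
        q = g.clone()
        alphas = []
        for s, y in zip(reversed(s_history), reversed(y_history)):
            rho = 1.0 / y.dot(s)               # ρ_i
            alpha = rho * s.dot(q)
            alphas.append(alpha)
            q -= alpha * y
        s_k, y_k = s_history[-1], y_history[-1]
        gamma = s_k.dot(y_k) / y_k.dot(y_k)    # γ_k = s·y / y·y, so H0 = γ_k I
        z = gamma * q
        for (s, y), alpha in zip(zip(s_history, y_history), reversed(alphas)):
            rho = 1.0 / y.dot(s)
            beta = rho * y.dot(z)
            z += (alpha - beta) * s
        return z                                # approximates H^-1 g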
@@ -165,9 +164,9 @@ class ModularLBFGS(Module):
          self.global_state['sy_history'].clear()
 
      @torch.no_grad
-     def step(self, vars):
-         params = as_tensorlist(vars.params)
-         update = as_tensorlist(vars.get_update())
+     def step(self, var):
+         params = as_tensorlist(var.params)
+         update = as_tensorlist(var.get_update())
          step = self.global_state.get('step', 0)
          self.global_state['step'] = step + 1
 
@@ -186,11 +185,11 @@ class ModularLBFGS(Module):
          params_h, update_h = _apply_tfms_into_history(
              self,
              params=params,
-             vars=vars,
+             var=var,
              update=update,
          )
 
-         prev_params_h, prev_grad_h = self.get_state('prev_params_h', 'prev_grad_h', params=params, cls=TensorList)
+         prev_params_h, prev_grad_h = self.get_state(params, 'prev_params_h', 'prev_grad_h', cls=TensorList)
 
          # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
          if step == 0:
@@ -217,16 +216,16 @@ class ModularLBFGS(Module):
          # step with inner module before applying preconditioner
          if 'update_precond_tfm' in self.children:
              update_precond_tfm = self.children['update_precond_tfm']
-             inner_vars = update_precond_tfm.step(vars.clone(clone_update=True))
-             vars.update_attrs_from_clone_(inner_vars)
-             tensors = inner_vars.update
+             inner_var = update_precond_tfm.step(var.clone(clone_update=True))
+             var.update_attrs_from_clone_(inner_var)
+             tensors = inner_var.update
              assert tensors is not None
          else:
              tensors = update.clone()
 
          # transforms into preconditioner
-         params_p, update_p = _apply_tfms_into_precond(self, params=params, vars=vars, update=update)
-         prev_params_p, prev_grad_p = self.get_state('prev_params_p', 'prev_grad_p', params=params, cls=TensorList)
+         params_p, update_p = _apply_tfms_into_precond(self, params=params, var=var, update=update)
+         prev_params_p, prev_grad_p = self.get_state(params, 'prev_params_p', 'prev_grad_p', cls=TensorList)
 
          if step == 0:
              s_k_p = None; y_k_p = None; ys_k_p = None
@@ -245,13 +244,13 @@ class ModularLBFGS(Module):
          # tolerance on gradient difference to avoid exploding after converging
          if tol is not None:
              if y_k_p is not None and y_k_p.abs().global_max() <= tol:
-                 vars.update = update # may have been updated by inner module, probably makes sense to use it here?
-                 return vars
+                 var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                 return var
 
          # precondition
          dir = lbfgs(
              tensors_=as_tensorlist(tensors),
-             vars=vars,
+             var=var,
              s_history=s_history,
              y_history=y_history,
              sy_history=sy_history,
@@ -260,7 +259,7 @@ class ModularLBFGS(Module):
              z_tfm=self.children.get('z_tfm', None),
          )
 
-         vars.update = dir
+         var.update = dir
 
-         return vars
+         return var
 
@@ -3,13 +3,13 @@ from typing import Any, Literal, overload
 
  import torch
 
- from ...core import Chainable, Module, apply, Modular
+ from ...core import Chainable, Module, apply_transform, Modular
  from ...utils import TensorList, as_tensorlist
  from ...utils.derivatives import hvp
  from ..quasi_newton import LBFGS
 
  class NewtonSolver(Module):
-     """Matrix-free Newton with any custom solver (usually it is better to just use NewtonCG, or NystromPCG which is even better)"""
+     """Matrix-free Newton with any custom solver (this is for testing, use NewtonCG or NystromPCG)"""
      def __init__(
          self,
          solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
@@ -26,9 +26,9 @@ class NewtonSolver(Module):
          self.set_child('inner', inner)
 
      @torch.no_grad
-     def step(self, vars):
-         params = TensorList(vars.params)
-         closure = vars.closure
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
          if closure is None: raise RuntimeError('NewtonCG requires closure')
 
          settings = self.settings[params[0]]
@@ -39,7 +39,7 @@ class NewtonSolver(Module):
          warm_start = settings['warm_start']
 
          # ---------------------- Hessian vector product function --------------------- #
-         grad = vars.get_grad(create_graph=True)
+         grad = var.get_grad(create_graph=True)
 
          def H_mm(x):
              with torch.enable_grad():
@@ -50,11 +50,11 @@ class NewtonSolver(Module):
          # -------------------------------- inner step -------------------------------- #
          b = as_tensorlist(grad)
          if 'inner' in self.children:
-             b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+             b = as_tensorlist(apply_transform(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, var=var))
 
          # ---------------------------------- run cg ---------------------------------- #
          x0 = None
-         if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+         if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
          if x0 is None: x = b.zeros_like().requires_grad_(True)
          else: x = x0.clone().requires_grad_(True)
 
@@ -76,13 +76,13 @@ class NewtonSolver(Module):
              assert loss is not None
              if min(loss, loss/initial_loss) < tol: break
 
-             print(f'{loss = }')
+             # print(f'{loss = }')
 
          if warm_start:
              assert x0 is not None
              x0.copy_(x)
 
-         vars.update = x.detach()
-         return vars
+         var.update = x.detach()
+         return var
 
 
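
For context, the pattern NewtonSolver relies on can be written standalone in a few lines: the Hessian is accessed only through Hessian-vector products (a double backward), and an iterative solver is run on H d = g. The sketch below uses plain conjugate gradient in place of the pluggable inner solver, on an assumed toy objective:

    import torch

    def newton_step_matrix_free(f, x, iters=50, tol=1e-10):
        g = torch.autograd.grad(f(x), x, create_graph=True)[0]
        def hvp(v):
            # Hv via double backward: grad of (g . v) w.r.t. x
            return torch.autograd.grad(g, x, grad_outputs=v, retain_graph=True)[0]
        # conjugate gradient on H d = g, starting from d = 0
        d = torch.zeros_like(x)
        r = g.detach().clone()
        p = r.clone()
        rs = r.dot(r)
        for _ in range(iters):
            Hp = hvp(p).detach()
            alpha = rs / p.dot(Hp)
            d += alpha * p
            r -= alpha * Hp
            rs_new = r.dot(r)
            if rs_new < tol: break
            p = r + (rs_new / rs) * p
            rs = rs_new
        return d

    # assumed convex test function, just for illustration
    f = lambda x: (x[0] - 1)**2 + 10 * (x[1] + 2)**2 + x[0] * x[1]
    x = torch.tensor([1.5, -0.5], requires_grad=True)
    d = newton_step_matrix_free(f, x)
    with torch.no_grad():
        x -= d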
@@ -0,0 +1,92 @@
+ import itertools
+ import warnings
+ from collections.abc import Callable
+ from contextlib import nullcontext
+ from functools import partial
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import TensorList, vec_to_tensors
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     jacobian_wrt,
+ )
+ from ..second_order.newton import (
+     cholesky_solve,
+     eigh_solve,
+     least_squares_solve,
+     lu_solve,
+ )
+
+
+ class NewtonNewton(Module):
+     """Applies Newton-like preconditioning to the Newton step.
+
+     This is a method that I thought of, and it turned out to work. Here is how it works:
+
+     1. Calculate the Newton step x by solving Hx = g.
+
+     2. Calculate the jacobian of x w.r.t. the parameters and call it H2.
+
+     3. Solve H2 x2 = x for x2.
+
+     4. Optionally, repeat (if the order is higher than 3).
+
+     Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex ones. Orders higher than 3 are usually too unstable and have little benefit.
+
+     The 3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
+     provided pytorch can vectorize the hessian computation efficiently.
+     """
+     def __init__(
+         self,
+         reg: float = 1e-6,
+         order: int = 3,
+         search_negative: bool = False,
+         vectorize: bool = True,
+         eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+     ):
+         defaults = dict(order=order, reg=reg, vectorize=vectorize, eigval_tfm=eigval_tfm, search_negative=search_negative)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None: raise RuntimeError('NewtonNewton requires closure')
+
+         settings = self.settings[params[0]]
+         reg = settings['reg']
+         vectorize = settings['vectorize']
+         order = settings['order']
+         search_negative = settings['search_negative']
+         eigval_tfm = settings['eigval_tfm']
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         with torch.enable_grad():
+             loss = var.loss = var.loss_approx = closure(False)
+             g_list = torch.autograd.grad(loss, params, create_graph=True)
+             var.grad = list(g_list)
+
+             xp = torch.cat([t.ravel() for t in g_list])
+             I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
+
+             for o in range(2, order + 1):
+                 is_last = o == order
+                 H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
+                 with torch.no_grad() if is_last else nullcontext():
+                     H = hessian_list_to_mat(H_list)
+                     if reg != 0: H = H + I * reg
+
+                     x = None
+                     if search_negative or (is_last and eigval_tfm is not None):
+                         x = eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+                     if x is None: x = cholesky_solve(H, xp)
+                     if x is None: x = lu_solve(H, xp)
+                     if x is None: x = least_squares_solve(H, xp)
+                     xp = x.squeeze()
+
+         var.update = vec_to_tensors(xp.nan_to_num_(0,0,0), params)
+         return var
+
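
To make the recipe in the docstring concrete, here is a hypothetical standalone sketch of one order-3 step on a small convex test function; it builds each Jacobian explicitly with autograd instead of torchzero's jacobian_wrt/hessian_list_to_mat helpers, and the function and names are purely illustrative:

    import torch

    def solve_against_jacobian(vec, x, reg=1e-6):
        # build the Jacobian of `vec` w.r.t. x row by row, then solve (J + reg*I) d = vec
        n = vec.numel()
        rows = [torch.autograd.grad(vec[i], x, retain_graph=True, create_graph=True)[0]
                for i in range(n)]
        J = torch.stack(rows) + reg * torch.eye(n)
        return torch.linalg.solve(J, vec)

    def f(x):  # assumed convex test function
        return (x[0] - 1)**4 + (x[1] + 2)**4 + 0.5 * (x[0] + x[1])**2

    x = torch.tensor([3.0, 3.0], requires_grad=True)
    g = torch.autograd.grad(f(x), x, create_graph=True)[0]

    d1 = solve_against_jacobian(g, x)    # step 1: Newton step, since the Jacobian of g is H
    d2 = solve_against_jacobian(d1, x)   # steps 2-3: H2 = Jacobian of d1, solve H2 x2 = d1

    with torch.no_grad():
        x -= d2                          # apply x2 as the update (subtracted, like a gradient step)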
@@ -0,0 +1,220 @@
+ import math
+ from collections.abc import Mapping
+ from operator import itemgetter
+
+ import torch
+
+ from ...core import Module
+ from ...utils import TensorList
+
+
+
+ def adaptive_tracking(
+     f,
+     f_0,
+     f_1,
+     t_0,
+     maxiter: int
+ ):
+
+     t = t_0
+     f_t = f(t)
+
+     # backtrack
+     if f_t > f_0:
+         if f_1 > f_0: t = min(0.5, t_0/2)
+         while f_t > f_0:
+             maxiter -= 1
+             if maxiter < 0: return 0, f_0
+             t = t/2
+             f_t = f(t) if t!=1 else f_1
+         return t, f_t
+
+     # forwardtrack
+     f_prev = f_t
+     t *= 2
+     f_t = f(t)
+     if f_prev < f_t: return t/2, f_prev
+     while f_prev >= f_t:
+         maxiter -= 1
+         if maxiter < 0: return t, f_t
+         f_prev = f_t
+         t *= 2
+         f_t = f(t)
+     return t/2, f_prev
+
+
+
+ class ParabolaSearch(Module):
+     """"""
+     def __init__(
+         self,
+         step_size: float = 1e-2,
+         adaptive: bool=True,
+         normalize: bool=False,
+         # method: str | None = None,
+         maxiter: int | None = 10,
+         # bracket=None,
+         # bounds=None,
+         # tol: float | None = None,
+         # options=None,
+     ):
+         if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+         defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
+         super().__init__(defaults)
+
+         import scipy.optimize
+         self.scopt = scipy.optimize
+
+
+     @torch.no_grad
+     def step(self, var):
+         x_0 = TensorList(var.params)
+         closure = var.closure
+         assert closure is not None
+         settings = self.settings[x_0[0]]
+         step_size = settings['step_size']
+         adaptive = settings['adaptive']
+         normalize = settings['normalize']
+         maxiter = settings['maxiter']
+         if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+
+         grad = TensorList(var.get_grad())
+         f_0 = var.get_loss(False)
+
+         scale = 1
+         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+         if adaptive: scale = grad.abs().mean().clip(min=1e-8)
+
+         # make step
+         v_0 = grad * (step_size/scale)
+         x_0 -= v_0
+         with torch.enable_grad():
+             f_1 = closure()
+             grad = x_0.grad
+
+         x_0 += v_0
+         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+         v_1 = grad * (step_size/scale)
+         a = v_1 - v_0
+
+         def parabolic_objective(t: float):
+             nonlocal x_0
+
+             step = v_0*t + 0.5*a*t**2
+             x_0 -= step
+             value = closure(False)
+             x_0 += step
+             return value.detach().cpu()
+
+         prev_t = self.global_state.get('prev_t', 2)
+         t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
+         self.global_state['prev_t'] = t
+
+         # method, bracket, bounds, tol, options, maxiter = itemgetter(
+         #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+
+         # if maxiter is not None:
+         #     options = dict(options) if isinstance(options, Mapping) else {}
+         #     options['maxiter'] = maxiter
+
+         # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+         # t = res.x
+
+         var.update = v_0*t + 0.5*a*t**2
+         return var
+
+ class CubicParabolaSearch(Module):
+     """"""
+     def __init__(
+         self,
+         step_size: float = 1e-2,
+         adaptive: bool=True,
+         normalize: bool=False,
+         # method: str | None = None,
+         maxiter: int | None = 10,
+         # bracket=None,
+         # bounds=None,
+         # tol: float | None = None,
+         # options=None,
+     ):
+         if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+         defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
+         super().__init__(defaults)
+
+         import scipy.optimize
+         self.scopt = scipy.optimize
+
+
+     @torch.no_grad
+     def step(self, var):
+         x_0 = TensorList(var.params)
+         closure = var.closure
+         assert closure is not None
+         settings = self.settings[x_0[0]]
+         step_size = settings['step_size']
+         adaptive = settings['adaptive']
+         maxiter = settings['maxiter']
+         normalize = settings['normalize']
+         if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+
+         grad = TensorList(var.get_grad())
+         f_0 = var.get_loss(False)
+
+         scale = 1
+         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+         if adaptive: scale = grad.abs().mean().clip(min=1e-8)
+
+         # make step
+         v_0 = grad * (step_size/scale)
+         x_0 -= v_0
+         with torch.enable_grad():
+             f_1 = closure()
+             grad = x_0.grad
+
+         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+         v_1 = grad * (step_size/scale)
+         a_0 = v_1 - v_0
+
+         # make another step
+         x_0 -= v_1
+         with torch.enable_grad():
+             f_2 = closure()
+             grad = x_0.grad
+
+         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+         v_2 = grad * (step_size/scale)
+         a_1 = v_2 - v_1
+
+         j = a_1 - a_0
+
+         x_0 += v_0
+         x_0 += v_1
+
+         def parabolic_objective(t: float):
+             nonlocal x_0
+
+             step = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
+             x_0 -= step
+             value = closure(False)
+             x_0 += step
+             return value
+
+
+         prev_t = self.global_state.get('prev_t', 2)
+         t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
+         self.global_state['prev_t'] = t
+
+         # method, bracket, bounds, tol, options, maxiter = itemgetter(
+         #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+
+         # if maxiter is not None:
+         #     options = dict(options) if isinstance(options, Mapping) else {}
+         #     options['maxiter'] = maxiter
+
+         # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+         # t = res.x
+
+         var.update = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
+         return var
+
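
For reference, the parabolic extrapolation that both classes above search along can be written standalone: take a scaled gradient step v_0, evaluate the gradient again to get v_1, treat a = v_1 - v_0 as a constant acceleration, and pick t along x(t) = x_0 - (v_0 t + a t^2 / 2). The sketch below uses a simplified doubling search in place of adaptive_tracking; the test function and step size are assumptions for illustration:

    import torch

    def f(x): return (x[0] - 3)**2 + 5 * (x[1] + 1)**2   # assumed quadratic test objective

    def grad_of(x):
        xd = x.detach().clone().requires_grad_(True)
        f(xd).backward()
        return xd.grad

    x0 = torch.tensor([0.0, 0.0])
    lr = 1e-2

    v0 = grad_of(x0) * lr            # first (scaled) gradient step
    v1 = grad_of(x0 - v0) * lr       # gradient step evaluated after moving by v0
    a = v1 - v0                      # change of the gradient step, used as acceleration

    def x_at(t):                     # point reached after following the parabola for "time" t
        return x0 - (v0 * t + 0.5 * a * t**2)

    # crude doubling search over t, standing in for adaptive_tracking
    t = 1.0
    while f(x_at(2 * t)) < f(x_at(t)):
        t *= 2
    print(t, f(x_at(t)))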