torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -6,10 +6,10 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import ema_, ema_sq_, sqrt_ema_sq_
-from .ema import EMASquared, SqrtEMASquared
-from .momentum import nag_
+from ..momentum.momentum import nag_
+from ..ops.higher_level import EMASquared, SqrtEMASquared
 
 
 def precentered_ema_sq_(
@@ -49,7 +49,7 @@ class PrecenteredEMASquared(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
@@ -154,44 +154,7 @@ class CoordinateMomentum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         p = NumberList(s['p'] for s in settings)
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
-
-
-# def multiplicative_momentum_(
-#     tensors_: TensorList,
-#     velocity_: TensorList,
-#     momentum: float | NumberList,
-#     dampening: float | NumberList,
-#     normalize_velocity: bool = True,
-#     abs: bool = False,
-#     lerp: bool = False,
-# ):
-#     """
-#     abs: if True, tracks momentum of absolute magnitudes.
-
-#     returns `tensors_`.
-#     """
-#     tensors_into_velocity = tensors_.abs() if abs else tensors_
-#     ema_(tensors_into_velocity, exp_avg_=velocity_, beta=momentum, dampening=0, lerp=lerp)
-
-#     if normalize_velocity: velocity_ = velocity_ / velocity_.std().add_(1e-8)
-#     return tensors_.mul_(velocity_.lazy_mul(1-dampening) if abs else velocity_.abs().lazy_mul_(1-dampening))
-
-
-# class MultiplicativeMomentum(Transform):
-#     """sucks"""
-#     def __init__(self, momentum: float = 0.9, dampening: float = 0,normalize_velocity: bool = True, abs: bool = False, lerp: bool = False):
-#         defaults = dict(momentum=momentum, dampening=dampening, normalize_velocity=normalize_velocity,abs=abs, lerp=lerp)
-#         super().__init__(defaults, uses_grad=False)
-
-#     @torch.no_grad
-#     def apply(self, tensors, params, grads, loss, states, settings):
-#         momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
-#         abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
-#         velocity = self.get_state('velocity', params=params, cls=TensorList)
-#         return multiplicative_momentum_(TensorList(target), velocity_=velocity, momentum=momentum, dampening=dampening,
-#                                         normalize_velocity=normalize_velocity,abs=abs,lerp=lerp)
-
@@ -3,28 +3,36 @@ from typing import Any, Literal, overload
 
 import torch
 
-from ...core import Chainable, Module, apply_transform, Modular
+from ...core import Chainable, Modular, Module, apply_transform
 from ...utils import TensorList, as_tensorlist
-from ...utils.derivatives import hvp
+from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
 from ..quasi_newton import LBFGS
 
+
 class NewtonSolver(Module):
-    """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)"""
+    """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)."""
     def __init__(
         self,
         solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
         maxiter=None,
-        tol=1e-3,
+        maxiter1=None,
+        tol:float | None=1e-3,
         reg: float = 0,
         warm_start=True,
+        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        reset_solver: bool = False,
+        h: float= 1e-3,
         inner: Chainable | None = None,
     ):
-        defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+        defaults = dict(tol=tol, h=h,reset_solver=reset_solver, maxiter=maxiter, maxiter1=maxiter1, reg=reg, warm_start=warm_start, solver=solver, hvp_method=hvp_method)
         super().__init__(defaults,)
 
         if inner is not None:
             self.set_child('inner', inner)
 
+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
@@ -34,19 +42,49 @@ class NewtonSolver(Module):
         settings = self.settings[params[0]]
         solver_cls = settings['solver']
         maxiter = settings['maxiter']
+        maxiter1 = settings['maxiter1']
         tol = settings['tol']
         reg = settings['reg']
+        hvp_method = settings['hvp_method']
         warm_start = settings['warm_start']
+        h = settings['h']
+        reset_solver = settings['reset_solver']
 
+        self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
-        grad = var.get_grad(create_graph=True)
+        if hvp_method == 'autograd':
+            grad = var.get_grad(create_graph=True)
 
-        def H_mm(x):
-            with torch.enable_grad():
-                Hvp = TensorList(hvp(params, grad, x, create_graph=True))
+            def H_mm(x):
+                self._num_hvps_last_step += 1
+                with torch.enable_grad():
+                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
                 if reg != 0: Hvp = Hvp + (x*reg)
                 return Hvp
 
+        else:
+
+            with torch.enable_grad():
+                grad = var.get_grad()
+
+            if hvp_method == 'forward':
+                def H_mm(x):
+                    self._num_hvps_last_step += 1
+                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                    if reg != 0: Hvp = Hvp + (x*reg)
+                    return Hvp
+
+            elif hvp_method == 'central':
+                def H_mm(x):
+                    self._num_hvps_last_step += 1
+                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                    if reg != 0: Hvp = Hvp + (x*reg)
+                    return Hvp
+
+            else:
+                raise ValueError(hvp_method)
+
+
         # -------------------------------- inner step -------------------------------- #
         b = as_tensorlist(grad)
         if 'inner' in self.children:
@@ -58,23 +96,46 @@ class NewtonSolver(Module):
         if x0 is None: x = b.zeros_like().requires_grad_(True)
         else: x = x0.clone().requires_grad_(True)
 
-        solver = solver_cls(x)
+
+        if 'solver' not in self.global_state:
+            if maxiter1 is not None: maxiter = maxiter1
+            solver = self.global_state['solver'] = solver_cls(x)
+            self.global_state['x'] = x
+
+        else:
+            if reset_solver:
+                solver = self.global_state['solver'] = solver_cls(x)
+            else:
+                solver_params = self.global_state['x']
+                solver_params.set_(x)
+                x = solver_params
+                solver = self.global_state['solver']
+
         def lstsq_closure(backward=True):
-            Hx = H_mm(x)
-            loss = (Hx-b).pow(2).global_mean()
+            Hx = H_mm(x).detach()
+            # loss = (Hx-b).pow(2).global_mean()
+            # if backward:
+            #     solver.zero_grad()
+            #     loss.backward(inputs=x)
+
+            residual = Hx - b
+            loss = residual.pow(2).global_mean()
             if backward:
-                solver.zero_grad()
-                loss.backward(inputs=x)
+                with torch.no_grad():
+                    H_residual = H_mm(residual)
+                    n = residual.global_numel()
+                    x.set_grad_((2.0 / n) * H_residual)
+
             return loss
 
         if maxiter is None: maxiter = b.global_numel()
         loss = None
-        initial_loss = lstsq_closure(False)
-        if initial_loss > tol:
+        initial_loss = lstsq_closure(False) if tol is not None else None # skip unnecessary closure if tol is None
+        if initial_loss is None or initial_loss > torch.finfo(b[0].dtype).eps:
             for i in range(maxiter):
                 loss = solver.step(lstsq_closure)
                 assert loss is not None
-                if min(loss, loss/initial_loss) < tol: break
+                if initial_loss is not None and loss/initial_loss < tol: break
 
         # print(f'{loss = }')
 
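Note on the rewritten lstsq_closure above: it minimizes L(x) = (1/n) * ||Hx - b||^2 over x, whose gradient is (2/n) * H^T (Hx - b); since H here is a (symmetric) Hessian operator, H^T = H, so x.set_grad_((2.0 / n) * H_residual) writes the exact gradient using one extra Hessian-vector product per backward instead of differentiating through the first one.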
@@ -83,6 +144,7 @@ class NewtonSolver(Module):
             x0.copy_(x)
 
         var.update = x.detach()
+        self._num_hvps += self._num_hvps_last_step
         return var
 
 
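Note: the hvp_method option added above selects between autograd and finite-difference Hessian-vector products. A minimal sketch of the two finite-difference schemes, assuming an illustrative grad_fn that returns the gradient at a point (this is not torchzero's hvp_fd_forward/hvp_fd_central signature):

import torch

def hvp_fd_forward_sketch(grad_fn, x, v, h=1e-3):
    # forward difference: Hv ~ (g(x + h*v) - g(x)) / h, one extra gradient evaluation
    return (grad_fn(x + h * v) - grad_fn(x)) / h

def hvp_fd_central_sketch(grad_fn, x, v, h=1e-3):
    # central difference: Hv ~ (g(x + h*v) - g(x - h*v)) / (2h), two evaluations, O(h^2) error
    return (grad_fn(x + h * v) - grad_fn(x - h * v)) / (2 * h)

# check on a quadratic f(x) = 0.5 * x^T A x whose Hessian is exactly A
A = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
grad_fn = lambda x: A @ x
print(hvp_fd_central_sketch(grad_fn, torch.tensor([1.0, -1.0]), torch.tensor([0.0, 1.0])))  # ~ A @ v = [0.5, 1.0]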
@@ -10,20 +10,21 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
-    hessian_list_to_mat,
+    flatten_jacobian,
     jacobian_wrt,
 )
 from ..second_order.newton import (
-    cholesky_solve,
-    eigh_solve,
-    least_squares_solve,
-    lu_solve,
+    _cholesky_solve,
+    _eigh_solve,
+    _least_squares_solve,
+    _lu_solve,
 )
-
+from ...utils.linalg.linear_operator import Dense
 
 class NewtonNewton(Module):
-    """
-    Method that I thought of and then it worked.
+    """Applies Newton-like preconditioning to Newton step.
+
+    This is a method that I thought of and then it worked. Here is how it works:
 
     1. Calculate newton step by solving Hx=g
 
@@ -34,6 +35,9 @@ class NewtonNewton(Module):
     4. Optionally, repeat (if order is higher than 3.)
 
     Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
+
+    3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
+    this is if pytorch can vectorize hessian computation efficiently.
     """
     def __init__(
         self,
@@ -47,10 +51,10 @@ class NewtonNewton(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
+    def update(self, var):
         params = TensorList(var.params)
         closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
+        if closure is None: raise RuntimeError('NewtonNewton requires closure')
 
         settings = self.settings[params[0]]
         reg = settings['reg']
@@ -60,6 +64,7 @@ class NewtonNewton(Module):
         eigval_tfm = settings['eigval_tfm']
 
         # ------------------------ calculate grad and hessian ------------------------ #
+        Hs = []
         with torch.enable_grad():
             loss = var.loss = var.loss_approx = closure(False)
             g_list = torch.autograd.grad(loss, params, create_graph=True)
@@ -72,17 +77,29 @@ class NewtonNewton(Module):
             is_last = o == order
             H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
             with torch.no_grad() if is_last else nullcontext():
-                H = hessian_list_to_mat(H_list)
+                H = flatten_jacobian(H_list)
                 if reg != 0: H = H + I * reg
+                Hs.append(H)
 
                 x = None
                 if search_negative or (is_last and eigval_tfm is not None):
-                    x = eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
-                if x is None: x = cholesky_solve(H, xp)
-                if x is None: x = lu_solve(H, xp)
-                if x is None: x = least_squares_solve(H, xp)
+                    x = _eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+                if x is None: x = _cholesky_solve(H, xp)
+                if x is None: x = _lu_solve(H, xp)
+                if x is None: x = _least_squares_solve(H, xp)
                 xp = x.squeeze()
 
+        self.global_state["Hs"] = Hs
+        self.global_state['xp'] = xp.nan_to_num_(0,0,0)
+
+    @torch.no_grad
+    def apply(self, var):
+        params = var.params
+        xp = self.global_state['xp']
         var.update = vec_to_tensors(xp, params)
         return var
 
+    def get_H(self, var):
+        Hs = self.global_state["Hs"]
+        if len(Hs) == 1: return Dense(Hs[0])
+        return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
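Note: a rough dense sketch of the recursion the NewtonNewton docstring describes (solve Hx = g, differentiate the solution with respect to the parameters, solve again), written with plain torch.autograd on a single flat tensor rather than the module's jacobian_wrt and _*_solve helpers:

import torch

def newton_newton_step(f, x, order=3, reg=1e-8):
    # step 1: gradient with a graph, so the solve chain stays differentiable
    x = x.detach().requires_grad_(True)
    xp = torch.autograd.grad(f(x), x, create_graph=True)[0]
    I = torch.eye(x.numel())
    for o in range(2, order + 1):
        last = o == order
        # Jacobian of the current direction w.r.t. x; for o == 2 this is the Hessian
        rows = [torch.autograd.grad(xp[i], x, retain_graph=True, create_graph=not last)[0]
                for i in range(x.numel())]
        H = torch.stack(rows) + reg * I
        xp = torch.linalg.solve(H, xp)   # solve H x = current direction, keep the graph
    return xp.detach()

step = newton_newton_step(lambda z: (z ** 4).sum() + (z ** 2).sum(), torch.tensor([2.0, -1.5]))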
@@ -4,19 +4,19 @@ from ...core import Target, Transform
 from ...utils import TensorList, unpack_states, unpack_dicts
 
 class ReduceOutwardLR(Transform):
-    """
-    When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
+    """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
 
-    A note on this is that it sounded good but its really bad in practice.
+    .. warning::
+        This sounded good but after testing turns out it sucks.
     """
     def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
         super().__init__(defaults, uses_grad=use_grad, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         params = TensorList(params)
         tensors = TensorList(tensors)
 
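Note: the rule ReduceOutwardLR applies, shown on plain tensors (a sketch; the module does this per parameter through its Transform machinery, with optional use_grad and invert):

import torch

w = torch.tensor([0.5, -0.2, 1.0])      # weights
u = torch.tensor([0.1, 0.3, -0.4])      # current update
mul = 0.5

same_sign = torch.sign(u) == torch.sign(w)
u = torch.where(same_sign, u * mul, u)  # matching-sign entries get lr * mul
# opposite-sign entries keep the full rate, so (per the docstring) the updates
# that move weights towards zero end up relatively larger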
@@ -0,0 +1,105 @@
+from typing import Literal, overload
+
+import torch
+from scipy.sparse.linalg import LinearOperator, gcrotmk
+
+from ...core import Chainable, Module, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist, generic_vector_norm, vec_to_tensors
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+from ...utils.linalg.solve import cg, minres
+
+
+class ScipyNewtonCG(Module):
+    """NewtonCG with scipy solvers (any from scipy.sparse.linalg)"""
+    def __init__(
+        self,
+        solver = gcrotmk,
+        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        h: float = 1e-3,
+        warm_start=False,
+        inner: Chainable | None = None,
+        kwargs: dict | None = None,
+    ):
+        defaults = dict(hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
+        if kwargs is None: kwargs = {}
+        self._kwargs = kwargs
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        solver = settings['solver']
+        h = settings['h']
+        warm_start = settings['warm_start']
+
+        self._num_hvps_last_step = 0
+        # ---------------------- Hessian vector product function --------------------- #
+        device = params[0].device; dtype=params[0].dtype
+        if hvp_method == 'autograd':
+            grad = var.get_grad(create_graph=True)
+
+            def H_mm(x_np):
+                self._num_hvps_last_step += 1
+                x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                with torch.enable_grad():
+                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
+                return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+        else:
+
+            with torch.enable_grad():
+                grad = var.get_grad()
+
+            if hvp_method == 'forward':
+                def H_mm(x_np):
+                    self._num_hvps_last_step += 1
+                    x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                    return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+            elif hvp_method == 'central':
+                def H_mm(x_np):
+                    self._num_hvps_last_step += 1
+                    x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                    return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+            else:
+                raise ValueError(hvp_method)
+
+        ndim = sum(p.numel() for p in params)
+        H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore
+
+        # -------------------------------- inner step -------------------------------- #
+        b = var.get_update()
+        if 'inner' in self.children:
+            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+        b = as_tensorlist(b)
+
+        # ---------------------------------- run cg ---------------------------------- #
+        x0 = None
+        if warm_start: x0 = self.global_state.get('x_prev', None) # initialized to 0 which is default anyway
+
+        x_np = solver(H, b.to_vec().nan_to_num().numpy(force=True), x0=x0, **self._kwargs)
+        if isinstance(x_np, tuple): x_np = x_np[0]
+
+        if warm_start:
+            self.global_state['x_prev'] = x_np
+
+        var.update = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)
+
+        self._num_hvps += self._num_hvps_last_step
+        return var
+
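Note: the new ScipyNewtonCG builds on the standard pattern of wrapping a torch Hessian-vector product as a scipy LinearOperator and handing it to a Krylov solver. A self-contained sketch on a toy objective:

import torch
from scipy.sparse.linalg import LinearOperator, gcrotmk

x = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)
loss = (x ** 2).sum() + x.sum() ** 2        # toy objective, Hessian = 2*I + 2*ones
(g,) = torch.autograd.grad(loss, x, create_graph=True)

def matvec(v_np):
    # autograd Hessian-vector product: differentiate g . v with respect to x
    v = torch.as_tensor(v_np, dtype=x.dtype)
    (hv,) = torch.autograd.grad(g, x, grad_outputs=v, retain_graph=True)
    return hv.detach().numpy()

n = x.numel()
H = LinearOperator(shape=(n, n), matvec=matvec)   # scipy probes matvec to infer dtype
step, info = gcrotmk(H, g.detach().numpy())       # solve H @ step = g
print(step, info)                                 # info == 0 on convergence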
@@ -5,36 +5,19 @@ import torch
 
 from ...core import Chainable
 from ...utils import vec_to_tensors, TensorList
-from ..optimizers.shampoo import _merge_small_dims
-from .projection import Projection
+from ..adaptive.shampoo import _merge_small_dims
+from ..projections import ProjectionBase
 
 
-class VectorProjection(Projection):
-    """
-    flattens and concatenates all parameters into a vector
-    """
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
-    @torch.no_grad
-    def project(self, tensors, var, current):
-        return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
-
-    @torch.no_grad
-    def unproject(self, tensors, var, current):
-        return vec_to_tensors(vec=tensors[0], reference=var.params)
-
-
-
-class TensorizeProjection(Projection):
+class TensorizeProjection(ProjectionBase):
     """flattens and concatenates all parameters into a vector and then reshapes it into a tensor"""
     def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
         defaults = dict(max_side=max_side)
         super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors, var, current):
-        params = var.params
+    def project(self, tensors, params, grads, loss, states, settings, current):
        max_side = self.settings[params[0]]['max_side']
        num_elems = sum(t.numel() for t in tensors)
@@ -60,23 +43,23 @@ class TensorizeProjection(Projection):
         return [vec.view(dims)]
 
     @torch.no_grad
-    def unproject(self, tensors, var, current):
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
         remainder = self.global_state['remainder']
         # warnings.warn(f'{tensors[0].shape = }')
-        vec = tensors[0].view(-1)
+        vec = projected_tensors[0].view(-1)
         if remainder > 0: vec = vec[:-remainder]
-        return vec_to_tensors(vec, var.params)
+        return vec_to_tensors(vec, params)
 
-class BlockPartition(Projection):
+class BlockPartition(ProjectionBase):
     """splits parameters into blocks (for now flatttens them and chunks)"""
     def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
         defaults = dict(max_size=max_size, batched=batched)
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors, var, current):
+    def project(self, tensors, params, grads, loss, states, settings, current):
         partitioned = []
-        for p,t in zip(var.params, tensors):
+        for p,t in zip(params, tensors):
             settings = self.settings[p]
             max_size = settings['max_size']
             n = t.numel()
@@ -101,10 +84,10 @@ class BlockPartition(Projection):
         return partitioned
 
     @torch.no_grad
-    def unproject(self, tensors, var, current):
-        ti = iter(tensors)
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        ti = iter(projected_tensors)
         unprojected = []
-        for p in var.params:
+        for p in params:
             settings = self.settings[p]
             n = p.numel()
 
@@ -124,28 +107,3 @@ class BlockPartition(Projection):
 
         return unprojected
 
-
-class TensorNormsProjection(Projection):
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-    @torch.no_grad
-    def project(self, tensors, var, current):
-        orig = self.get_state(var.params, f'{current}_orig')
-        torch._foreach_copy_(orig, tensors)
-
-        norms = torch._foreach_norm(tensors)
-        self.get_state(var.params, f'{current}_orig_norms', cls=TensorList).set_(norms)
-
-        return [torch.stack(norms)]
-
-    @torch.no_grad
-    def unproject(self, tensors, var, current):
-        orig = self.get_state(var.params, f'{current}_orig')
-        orig_norms = torch.stack(self.get_state(var.params, f'{current}_orig_norms'))
-        target_norms = tensors[0]
-
-        orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
-
-        torch._foreach_mul_(orig, (target_norms/orig_norms).detach().cpu().tolist())
-        return orig
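Note: the flatten-and-chunk round trip behind BlockPartition's project/unproject pair, sketched on plain tensors (the real module keeps per-parameter settings and optionally batches the blocks):

import math
import torch

params = [torch.randn(3, 4), torch.randn(5)]
max_size = 6

# project: flatten each parameter and split it into blocks of at most max_size
blocks = [c for p in params for c in p.reshape(-1).split(max_size)]

# unproject: hand each parameter its blocks back, concatenate and reshape
ti = iter(blocks)
restored = []
for p in params:
    k = math.ceil(p.numel() / max_size)
    restored.append(torch.cat([next(ti) for _ in range(k)]).view_as(p))

assert all(torch.equal(a, b) for a, b in zip(params, restored))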
@@ -7,10 +7,19 @@ storage is always indicated in the docstring.
 
 Additional functional variants are present in most module files, e.g. `adam_`, `rmsprop_`, `lion_`, etc.
 """
+from collections.abc import Callable
+from typing import overload
 
-from collections.abc import Callable, Sequence
+import torch
 
-from ..utils import NumberList, TensorList
+from ..utils import (
+    NumberList,
+    TensorList,
+    generic_finfo_eps,
+    generic_max,
+    generic_sum,
+    tofloat,
+)
 
 inf = float('inf')
 
@@ -86,10 +95,10 @@ def root(tensors_:TensorList, p:float, inplace: bool):
         if p == 1: return tensors_.abs_()
         if p == 2: return tensors_.sqrt_()
         return tensors_.pow_(1/p)
-    else:
-        if p == 1: return tensors_.abs()
-        if p == 2: return tensors_.sqrt()
-        return tensors_.pow(1/p)
+
+    if p == 1: return tensors_.abs()
+    if p == 2: return tensors_.sqrt()
+    return tensors_.pow(1/p)
 
 
 def ema_(
@@ -206,4 +215,41 @@ def sqrt_centered_ema_sq_(
         ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
     )
 
+def initial_step_size(tensors: torch.Tensor | TensorList, eps=None) -> float:
+    """initial scaling taken from pytorch L-BFGS to avoid requiring a lot of line search iterations,
+    this version is safer and makes sure largest value isn't smaller than epsilon."""
+    tensors_abs = tensors.abs()
+    tensors_sum = generic_sum(tensors_abs)
+    tensors_max = generic_max(tensors_abs)
+
+    feps = generic_finfo_eps(tensors)
+    if eps is None: eps = feps
+    else: eps = max(eps, feps)
+
+    # scale should not make largest value smaller than epsilon
+    min = eps / tensors_max
+    if min >= 1: return 1.0
+
+    scale = 1 / tensors_sum
+    scale = scale.clip(min=min.item(), max=1)
+    return scale.item()
+
+
+def epsilon_step_size(tensors: torch.Tensor | TensorList, alpha=1e-7) -> float:
+    """makes sure largest value isn't smaller than epsilon."""
+    tensors_abs = tensors.abs()
+    tensors_max = generic_max(tensors_abs)
+    if tensors_max < alpha: return 1.0
+
+    if tensors_max < 1: alpha = alpha / tensors_max
+    return tofloat(alpha)
+
+
+
+def safe_clip(x: torch.Tensor, min=None):
+    """makes sure absolute value of scalar tensor x is not smaller than min"""
+    assert x.numel() == 1, x.shape
+    if min is None: min = torch.finfo(x.dtype).tiny * 2
 
+    if x.abs() < min: return x.new_full(x.size(), min).copysign(x)
+    return x
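Note: how the new initial_step_size helper behaves, sketched for a single dense tensor (the shipped version uses the generic_* utilities above so it also accepts TensorList):

import torch

def initial_step_size_dense(t: torch.Tensor, eps=None) -> float:
    a = t.abs()
    feps = torch.finfo(t.dtype).eps
    eps = feps if eps is None else max(eps, feps)
    lo = eps / a.max()               # never scale the largest entry below epsilon
    if lo >= 1: return 1.0
    return (1 / a.sum()).clip(min=lo.item(), max=1).item()

g = torch.tensor([3.0, 4.0])
print(initial_step_size_dense(g))    # 1 / |g|_1 = 1/7, the pytorch L-BFGS-style scaling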