torchzero 0.3.6__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. tests/test_opts.py +1 -1
  2. tests/test_tensorlist.py +1 -1
  3. torchzero/core/preconditioner.py +12 -11
  4. torchzero/modules/experimental/__init__.py +3 -2
  5. torchzero/modules/experimental/absoap.py +8 -2
  6. torchzero/modules/experimental/adadam.py +1 -1
  7. torchzero/modules/experimental/adamY.py +1 -1
  8. torchzero/modules/experimental/adasoap.py +2 -2
  9. torchzero/modules/experimental/algebraic_newton.py +1 -1
  10. torchzero/modules/experimental/curveball.py +1 -1
  11. torchzero/modules/experimental/gradmin.py +1 -1
  12. torchzero/modules/experimental/newton_solver.py +88 -0
  13. torchzero/modules/experimental/{dsoap.py → soapy.py} +4 -4
  14. torchzero/modules/experimental/spectral.py +5 -3
  15. torchzero/modules/experimental/structured_newton.py +111 -0
  16. torchzero/modules/experimental/subspace_preconditioners.py +16 -9
  17. torchzero/modules/optimizers/soap.py +1 -2
  18. torchzero/modules/projections/projection.py +27 -1
  19. torchzero/modules/quasi_newton/cg.py +9 -9
  20. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +4 -3
  21. torchzero/modules/quasi_newton/lbfgs.py +4 -3
  22. torchzero/modules/quasi_newton/lsr1.py +7 -3
  23. torchzero/modules/quasi_newton/quasi_newton.py +18 -17
  24. torchzero/modules/second_order/__init__.py +1 -1
  25. torchzero/modules/second_order/newton.py +11 -6
  26. torchzero/modules/second_order/newton_cg.py +3 -3
  27. torchzero/modules/second_order/nystrom.py +6 -6
  28. torchzero/utils/linalg/benchmark.py +20 -0
  29. torchzero/utils/linalg/solve.py +15 -14
  30. {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/METADATA +2 -2
  31. {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/RECORD +34 -31
  32. {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/WHEEL +0 -0
  33. {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/licenses/LICENSE +0 -0
  34. {torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/top_level.txt +0 -0
tests/test_opts.py CHANGED
@@ -745,7 +745,7 @@ SSVM = Run(
745
745
  func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
746
746
  sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
747
747
  needs_closure=True,
748
- func='rosen', steps=50, loss=1e-12, merge_invariant=True,
748
+ func='rosen', steps=50, loss=1e-10, merge_invariant=True,
749
749
  sphere_steps=10, sphere_loss=0,
750
750
  )
751
751
 
tests/test_tensorlist.py CHANGED
@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho
835
835
  expected = vec_equiv_func()
836
836
 
837
837
  if isinstance(result, bool): assert result == expected
838
- else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
838
+ else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"
839
839
 
840
840
 
841
841
  def test_global_vector_norm(simple_tl: TensorList):
torchzero/core/preconditioner.py CHANGED
@@ -38,17 +38,18 @@ class Preconditioner(Transform):
38
38
 
39
39
 
40
40
  def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
41
- step = self.global_state.get('step', 0)
41
+ step = self.global_state.get('__step', 0)
42
42
  states = [self.state[p] for p in params]
43
43
  settings = [self.settings[p] for p in params]
44
44
  global_settings = settings[0]
45
45
  update_freq = global_settings['__update_freq']
46
46
 
47
47
  scale_first = global_settings['__scale_first']
48
- scale_factor = 0
48
+ scale_factor = 1
49
49
  if scale_first and step == 0:
50
50
  # initial step size guess from pytorch LBFGS
51
- scale_factor = TensorList(tensors).abs().sum()
51
+ scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
52
+ scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
52
53
 
53
54
  # update preconditioner
54
55
  if step % update_freq == 0:
@@ -63,13 +64,13 @@ class Preconditioner(Transform):
63
64
 
64
65
  # scale initial step, when preconditioner might not have been applied
65
66
  if scale_first and step == 0:
66
- torch._foreach_div_(tensors, scale_factor)
67
+ torch._foreach_mul_(tensors, scale_factor)
67
68
 
68
- self.global_state['step'] = step + 1
69
+ self.global_state['__step'] = step + 1
69
70
  return tensors
70
71
 
71
72
  def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
72
- step = self.global_state.get('step', 0)
73
+ step = self.global_state.get('__step', 0)
73
74
  tensors_vec = torch.cat([t.ravel() for t in tensors])
74
75
  params_vec = torch.cat([p.ravel() for p in params])
75
76
  grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
@@ -80,10 +81,11 @@ class Preconditioner(Transform):
80
81
  update_freq = global_settings['__update_freq']
81
82
 
82
83
  scale_first = global_settings['__scale_first']
83
- scale_factor = 0
84
+ scale_factor = 1
84
85
  if scale_first and step == 0:
85
86
  # initial step size guess from pytorch LBFGS
86
- scale_factor = tensors_vec.abs().sum()
87
+ scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
88
+ scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)
87
89
 
88
90
  # update preconditioner
89
91
  if step % update_freq == 0:
@@ -99,11 +101,10 @@ class Preconditioner(Transform):
99
101
 
100
102
  # scale initial step, when preconditioner might not have been applied
101
103
  if scale_first and step == 0:
102
- if scale_factor >= torch.finfo(tensors_vec.dtype).eps:
103
- tensors_vec /= scale_factor
104
+ tensors_vec *= scale_factor
104
105
 
105
106
  tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
106
- self.global_state['step'] = step + 1
107
+ self.global_state['__step'] = step + 1
107
108
  return tensors
108
109
 
109
110
  @torch.no_grad
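The recurring change in this release (here, and in lbfgs.py, lsr1.py and quasi_newton.py below) replaces "divide by sum(|g|)" with "multiply by a clipped reciprocal", so a near-zero first gradient can no longer blow the first step up or zero it out. A minimal standalone sketch of that scaling, assuming plain torch tensors; the helper name is illustrative, not part of the package:

    import torch

    def first_step_scale_(tensors):
        # reciprocal of the total absolute magnitude, bounded to [eps, 1]
        total = sum(t.abs().sum() for t in tensors)
        eps = torch.finfo(tensors[0].dtype).eps
        scale = (1.0 / total.clamp(min=1)).clamp(min=eps)
        for t in tensors:
            t.mul_(scale)   # multiply the update instead of dividing by a possibly tiny sum
        return tensors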
torchzero/modules/experimental/__init__.py CHANGED
@@ -3,7 +3,7 @@ from .adadam import Adadam
3
3
  from .adamY import AdamY
4
4
  from .adasoap import AdaSOAP
5
5
  from .curveball import CurveBall
6
- from .dsoap import DSOAP
6
+ from .soapy import SOAPY
7
7
  from .gradmin import GradMin
8
8
  from .reduce_outward_lr import ReduceOutwardLR
9
9
  from .spectral import SpectralPreconditioner
@@ -11,4 +11,5 @@ from .subspace_preconditioners import (
11
11
  HistorySubspacePreconditioning,
12
12
  RandomSubspacePreconditioning,
13
13
  )
14
- from .tropical_newton import TropicalNewton
14
+ from .tropical_newton import TropicalNewton
15
+ from .newton_solver import NewtonSolver
torchzero/modules/experimental/absoap.py CHANGED
@@ -140,11 +140,17 @@ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
140
140
  class ABSOAP(Transform):
141
141
  """SOAP but with two extra letters included in its name in order to improve converence
142
142
 
143
+ so what you can do is choose what goes into what ,and that is supposed to be good.
144
+
143
145
  new args
144
146
 
145
147
  scale by s whether to scale gradient differences by parameter differences
146
148
 
147
149
  y_to_ema2 whether to use gradient differences for exponential moving average too
150
+
151
+ okay I changed these args into another ones
152
+
153
+ BASICALLY THIS IS FOR MY EXPERIMENTS
148
154
  """
149
155
  def __init__(
150
156
  self,
@@ -213,7 +219,7 @@ class ABSOAP(Transform):
213
219
  if 'g_prev' not in state:
214
220
  state['p_prev'] = p.clone()
215
221
  state['g_prev'] = t.clone()
216
- updates.append(tensors[i].sign())
222
+ updates.append(tensors[i].clip(-0.1,0.1))
217
223
  continue
218
224
 
219
225
  p_prev = state['p_prev']
@@ -285,7 +291,7 @@ class ABSOAP(Transform):
285
291
  state['Q'] = get_orthogonal_matrix(state['GG'])
286
292
 
287
293
  state['step'] = 0
288
- updates.append(tensors[i].sign())
294
+ updates.append(tensors[i].clip(-0.1,0.1))
289
295
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
290
296
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
291
297
 
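Several modules in this release (ABSOAP above, and AdaSOAP, SOAPY, SOAP, AdamY, SpectralPreconditioner below) swap the first-step fallback from tensors.sign() to tensors.clip(-0.1, 0.1). A small editorial illustration of the difference:

    import torch

    g = torch.tensor([2.0, 0.003, -0.5])
    g.sign()            # tensor([ 1.,  1., -1.])  -- every entry forced to magnitude 1
    g.clip(-0.1, 0.1)   # tensor([ 0.1000,  0.0030, -0.1000])  -- bounded, small entries preserved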
torchzero/modules/experimental/adadam.py CHANGED
@@ -50,7 +50,7 @@ def adadam_(
50
50
  return None
51
51
 
52
52
  class Adadam(Module):
53
- """Adam with a diagonally preconditioned preconditioner and a graceful name."""
53
+ """Adam with a diagonally preconditioned preconditioner."""
54
54
  def __init__(
55
55
  self,
56
56
  beta1: float = 0.9,
torchzero/modules/experimental/adamY.py CHANGED
@@ -37,7 +37,7 @@ def adamy_(
37
37
  p_prev.copy_(p)
38
38
  g_prev.copy_(g)
39
39
 
40
- update = g.sign().lazy_mul_(alpha*0.1)
40
+ update = g.clip(-0.1,0.1).lazy_mul_(alpha)
41
41
  if params_ is None: return update
42
42
  params_.sub_(update)
43
43
  return None
torchzero/modules/experimental/adasoap.py CHANGED
@@ -218,9 +218,9 @@ class AdaSOAP(Transform):
218
218
  state['Q'] = get_orthogonal_matrix(GG_precond)
219
219
 
220
220
  state['step'] = 0
221
- updates.append(tensors[i].sign())
221
+ updates.append(tensors[i].clip(-0.1,0.1))
222
222
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
223
- # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
223
+ # that can mess with other modules scaling
224
224
 
225
225
  # Projecting gradients to the eigenbases of Shampoo's preconditioner
226
226
  # i.e. projecting to the eigenbases of matrices in state['GG']
torchzero/modules/experimental/algebraic_newton.py CHANGED
@@ -71,7 +71,7 @@ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemir
71
71
 
72
72
 
73
73
  class AlgebraicNewton(Module):
74
- """newton in other algebras, not practical because solving linear system is very hard."""
74
+ """newton in other algebras, not that it works."""
75
75
  def __init__(
76
76
  self,
77
77
  reg: float | None = None,
torchzero/modules/experimental/curveball.py CHANGED
@@ -13,7 +13,7 @@ def curveball(
13
13
  momentum: float | NumberList,
14
14
  precond_lr: float | NumberList,
15
15
  ):
16
- """returns z_, clone it!!!"""
16
+ """returns z_, clone it!!! (no just negate it)"""
17
17
  delta = Hz + tensors
18
18
  z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
19
19
  return z_
torchzero/modules/experimental/gradmin.py CHANGED
@@ -14,7 +14,7 @@ from ..smoothing.gaussian import Reformulation
14
14
 
15
15
 
16
16
  class GradMin(Reformulation):
17
- """Reformulates the objective to minimize sum of gradient magnitudes via autograd.
17
+ """Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.
18
18
 
19
19
  Args:
20
20
  loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
torchzero/modules/experimental/newton_solver.py ADDED
@@ -0,0 +1,88 @@
1
+ from collections.abc import Callable, Iterable
2
+ from typing import Any, Literal, overload
3
+
4
+ import torch
5
+
6
+ from ...core import Chainable, Module, apply, Modular
7
+ from ...utils import TensorList, as_tensorlist
8
+ from ...utils.derivatives import hvp
9
+ from ..quasi_newton import LBFGS
10
+
11
+ class NewtonSolver(Module):
12
+ """Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
13
+ def __init__(
14
+ self,
15
+ solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
16
+ maxiter=None,
17
+ tol=1e-3,
18
+ reg: float = 0,
19
+ warm_start=True,
20
+ inner: Chainable | None = None,
21
+ ):
22
+ defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
23
+ super().__init__(defaults,)
24
+
25
+ if inner is not None:
26
+ self.set_child('inner', inner)
27
+
28
+ @torch.no_grad
29
+ def step(self, vars):
30
+ params = TensorList(vars.params)
31
+ closure = vars.closure
32
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
33
+
34
+ settings = self.settings[params[0]]
35
+ solver_cls = settings['solver']
36
+ maxiter = settings['maxiter']
37
+ tol = settings['tol']
38
+ reg = settings['reg']
39
+ warm_start = settings['warm_start']
40
+
41
+ # ---------------------- Hessian vector product function --------------------- #
42
+ grad = vars.get_grad(create_graph=True)
43
+
44
+ def H_mm(x):
45
+ with torch.enable_grad():
46
+ Hvp = TensorList(hvp(params, grad, x, create_graph=True))
47
+ if reg != 0: Hvp = Hvp + (x*reg)
48
+ return Hvp
49
+
50
+ # -------------------------------- inner step -------------------------------- #
51
+ b = as_tensorlist(grad)
52
+ if 'inner' in self.children:
53
+ b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
54
+
55
+ # ---------------------------------- run cg ---------------------------------- #
56
+ x0 = None
57
+ if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
58
+ if x0 is None: x = b.zeros_like().requires_grad_(True)
59
+ else: x = x0.clone().requires_grad_(True)
60
+
61
+ solver = solver_cls(x)
62
+ def lstsq_closure(backward=True):
63
+ Hx = H_mm(x)
64
+ loss = (Hx-b).pow(2).global_mean()
65
+ if backward:
66
+ solver.zero_grad()
67
+ loss.backward(inputs=x)
68
+ return loss
69
+
70
+ if maxiter is None: maxiter = b.global_numel()
71
+ loss = None
72
+ initial_loss = lstsq_closure(False)
73
+ if initial_loss > tol:
74
+ for i in range(maxiter):
75
+ loss = solver.step(lstsq_closure)
76
+ assert loss is not None
77
+ if min(loss, loss/initial_loss) < tol: break
78
+
79
+ print(f'{loss = }')
80
+
81
+ if warm_start:
82
+ assert x0 is not None
83
+ x0.copy_(x)
84
+
85
+ vars.update = x.detach()
86
+ return vars
87
+
88
+
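NewtonSolver approximates the Newton step by solving H x = g as a least-squares problem, minimizing ||H x − g||² with an inner optimizer and using only Hessian-vector products. A self-contained sketch of that idea with a stand-in torch.optim.Adam inner loop (the module itself defaults to an LBFGS-based Modular solver; the function name below is illustrative):

    import torch

    def newton_direction_via_lstsq(closure, params, iters=100, lr=1e-2):
        loss = closure()
        grad = torch.autograd.grad(loss, params, create_graph=True)
        x = [torch.zeros_like(p, requires_grad=True) for p in params]
        opt = torch.optim.Adam(x, lr=lr)                    # stand-in for the inner solver
        for _ in range(iters):
            gx = sum((g * xi).sum() for g, xi in zip(grad, x))
            Hx = torch.autograd.grad(gx, params, create_graph=True, retain_graph=True)  # H @ x
            residual = sum((h - g.detach()).pow(2).sum() for h, g in zip(Hx, grad))
            opt.zero_grad()
            residual.backward(inputs=x, retain_graph=True)  # only x receives gradients
            opt.step()
        return [xi.detach() for xi in x]                    # approximate H^{-1} g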
torchzero/modules/experimental/{dsoap.py → soapy.py} RENAMED
@@ -3,7 +3,7 @@ from operator import itemgetter
3
3
  import torch
4
4
 
5
5
  from ...core import Chainable, Transform, apply
6
- from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
6
+ from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
7
7
 
8
8
  @torch.no_grad
9
9
  def update_soap_covariances_(
@@ -135,7 +135,7 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
135
135
 
136
136
  return final, exp_avg_sq
137
137
 
138
- class DSOAP(Transform):
138
+ class SOAPY(Transform):
139
139
  """SOAP but uses scaled gradient differences
140
140
 
141
141
  new args
@@ -195,7 +195,7 @@ class DSOAP(Transform):
195
195
  if 'g_prev' not in state:
196
196
  state['p_prev'] = p.clone()
197
197
  state['g_prev'] = t.clone()
198
- updates.append(tensors[i].sign())
198
+ updates.append(tensors[i].clip(-0.1,0.1))
199
199
  continue
200
200
 
201
201
  p_prev = state['p_prev']
@@ -228,7 +228,7 @@ class DSOAP(Transform):
228
228
  state['Q'] = get_orthogonal_matrix(state['GG'])
229
229
 
230
230
  state['step'] = 0
231
- updates.append(tensors[i].sign())
231
+ updates.append(tensors[i].clip(-0.1,0.1))
232
232
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
233
233
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
234
234
 
torchzero/modules/experimental/spectral.py CHANGED
@@ -194,8 +194,10 @@ class SpectralPreconditioner(TensorwisePreconditioner):
194
194
  order (int, optional):
195
195
  whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
196
196
  solver (str, optional): what to use for whitening. Defaults to 'svd'.
197
- U_beta (float | None, optional): beta for U (probably a bad idea). Defaults to None.
198
- S_beta (float | None, optional): beta for S (probably a bad idea). Defaults to None.
197
+ A_beta (float | None, optional):
198
+ beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
199
+ B_beta (float | None, optional):
200
+ beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
199
201
  interval (int, optional): How often to update history. Defaults to 1 (every step).
200
202
  concat_params (bool, optional):
201
203
  whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
@@ -275,7 +277,7 @@ class SpectralPreconditioner(TensorwisePreconditioner):
275
277
  A = state.get('A', None)
276
278
  if A is None:
277
279
  # make a conservative step to avoid issues due to different GD scaling
278
- return tensor.div_(max(1, tensor.abs().sum())) # pyright:ignore[reportArgumentType]
280
+ return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
279
281
 
280
282
  B = state['B']
281
283
  update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
torchzero/modules/experimental/structured_newton.py ADDED
@@ -0,0 +1,111 @@
1
+ # idea https://arxiv.org/pdf/2212.09841
2
+ import warnings
3
+ from collections.abc import Callable
4
+ from functools import partial
5
+ from typing import Literal
6
+
7
+ import torch
8
+
9
+ from ...core import Chainable, Module, apply
10
+ from ...utils import TensorList, vec_to_tensors
11
+ from ...utils.derivatives import (
12
+ hessian_list_to_mat,
13
+ hessian_mat,
14
+ hvp,
15
+ hvp_fd_central,
16
+ hvp_fd_forward,
17
+ jacobian_and_hessian_wrt,
18
+ )
19
+
20
+
21
+ class StructuredNewton(Module):
22
+ """TODO
23
+ Args:
24
+ structure (str, optional): structure.
25
+ reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
26
+ hvp_method (str):
27
+ how to calculate hvp_method. Defaults to "autograd".
28
+ inner (Chainable | None, optional): inner modules. Defaults to None.
29
+
30
+ """
31
+ def __init__(
32
+ self,
33
+ structure: Literal[
34
+ "diagonal",
35
+ "diagonal1",
36
+ "diagonal_abs",
37
+ "tridiagonal",
38
+ "circulant",
39
+ "toeplitz",
40
+ "toeplitz_like",
41
+ "hankel",
42
+ "rank1",
43
+ "rank2", # any rank
44
+ ]
45
+ | str = "diagonal",
46
+ reg: float = 1e-6,
47
+ hvp_method: Literal["autograd", "forward", "central"] = "autograd",
48
+ h: float = 1e-3,
49
+ inner: Chainable | None = None,
50
+ ):
51
+ defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
52
+ super().__init__(defaults)
53
+
54
+ if inner is not None:
55
+ self.set_child('inner', inner)
56
+
57
+ @torch.no_grad
58
+ def step(self, vars):
59
+ params = TensorList(vars.params)
60
+ closure = vars.closure
61
+ if closure is None: raise RuntimeError('NewtonCG requires closure')
62
+
63
+ settings = self.settings[params[0]]
64
+ reg = settings['reg']
65
+ hvp_method = settings['hvp_method']
66
+ structure = settings['structure']
67
+ h = settings['h']
68
+
69
+ # ------------------------ calculate grad and hessian ------------------------ #
70
+ if hvp_method == 'autograd':
71
+ grad = vars.get_grad(create_graph=True)
72
+ def Hvp_fn1(x):
73
+ return hvp(params, grad, x, retain_graph=True)
74
+ Hvp_fn = Hvp_fn1
75
+
76
+ elif hvp_method == 'forward':
77
+ grad = vars.get_grad()
78
+ def Hvp_fn2(x):
79
+ return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
80
+ Hvp_fn = Hvp_fn2
81
+
82
+ elif hvp_method == 'central':
83
+ grad = vars.get_grad()
84
+ def Hvp_fn3(x):
85
+ return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
86
+ Hvp_fn = Hvp_fn3
87
+
88
+ else: raise ValueError(hvp_method)
89
+
90
+ # -------------------------------- inner step -------------------------------- #
91
+ update = vars.get_update()
92
+ if 'inner' in self.children:
93
+ update = apply(self.children['inner'], update, params=params, grads=grad, vars=vars)
94
+
95
+ # hessian
96
+ if structure.startswith('diagonal'):
97
+ H = Hvp_fn([torch.ones_like(p) for p in params])
98
+ if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
99
+ if structure == 'diagonal_abs': torch._foreach_abs_(H)
100
+ torch._foreach_add_(H, reg)
101
+ torch._foreach_div_(update, H)
102
+ vars.update = update
103
+ return vars
104
+
105
+ # hessian
106
+ raise NotImplementedError(structure)
107
+
108
+
109
+
110
+
111
+
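The only structure implemented so far is the diagonal family: a single Hessian-vector product with an all-ones vector gives the Hessian row sums (equal to the true diagonal only when the Hessian is actually diagonal), which are clamped or taken in absolute value, regularized, and divided into the update. A standalone sketch of the "diagonal_abs" variant (function name is illustrative):

    import torch

    def diagonal_abs_newton_direction(closure, params, reg=1e-6):
        loss = closure()
        grad = torch.autograd.grad(loss, params, create_graph=True)
        ones = [torch.ones_like(p) for p in params]
        gv = sum((g * v).sum() for g, v in zip(grad, ones))
        H_ones = torch.autograd.grad(gv, params)          # H @ 1, i.e. Hessian row sums
        return [g.detach() / (h.abs() + reg) for g, h in zip(grad, H_ones)]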
torchzero/modules/experimental/subspace_preconditioners.py CHANGED
@@ -38,16 +38,19 @@ def apply_subspace_preconditioner(
38
38
  return basis @ update_projected # d
39
39
 
40
40
  class RandomSubspacePreconditioning(Transform):
41
- """full matrix rmsprop in random subspace"""
42
- def __init__(self, k: int, beta: float | None = 0.99):
43
- defaults = dict(k=k, beta=beta)
41
+ """full matrix rmsprop in random slowly changing subspace"""
42
+ def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
43
+ defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
44
44
  super().__init__(defaults, uses_grad=False)
45
45
 
46
+ if inner is not None: self.set_child('inner', inner)
47
+
46
48
  def transform(self, tensors, params, grads, vars):
47
49
  settings = self.settings[params[0]]
48
50
  g = torch.cat([t.view(-1) for t in tensors])
49
51
  k = settings['k']
50
52
  beta = settings['beta']
53
+ basis_beta = settings['basis_beta']
51
54
 
52
55
  if 'basis' not in self.global_state:
53
56
  self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
@@ -56,13 +59,19 @@ class RandomSubspacePreconditioning(Transform):
56
59
  basis = self.global_state['basis']
57
60
  accumulator = self.global_state['accumulator']
58
61
 
62
+ if basis_beta is not None:
63
+ basis.lerp_(torch.randn_like(basis), 1-basis_beta)
64
+
59
65
  update_subspace_preconditioner_(g, basis, accumulator, beta)
66
+
67
+ if 'inner' in self.children:
68
+ tensors = apply(self.children['inner'], tensors, params, grads, vars)
69
+ g = torch.cat([t.view(-1) for t in tensors])
70
+
60
71
  try:
61
72
  preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
62
73
  except torch.linalg.LinAlgError:
63
- denom = g.abs().sum()
64
- if denom <= 1e-10: denom = torch.ones_like(denom)
65
- preconditioned = g / g.abs().sum()
74
+ preconditioned = g.clip(-0.1, 0.1)
66
75
  vec_to_tensors_(preconditioned, tensors)
67
76
 
68
77
  return tensors
@@ -119,9 +128,7 @@ class HistorySubspacePreconditioning(Transform):
119
128
  try:
120
129
  preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
121
130
  except torch.linalg.LinAlgError:
122
- denom = g.abs().sum()
123
- if denom <= 1e-10: denom = torch.ones_like(denom)
124
- preconditioned = g / g.abs().sum()
131
+ preconditioned = g.clip(-0.1,0.1)
125
132
  vec_to_tensors_(preconditioned, tensors)
126
133
 
127
134
  return tensors
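For reference, "full matrix rmsprop in a random slowly changing subspace" amounts to projecting the gradient onto a k-dimensional random basis, keeping an EMA of projected outer products, whitening in that subspace, and mapping back; the new basis_beta additionally drifts the basis toward a fresh random matrix each step. A minimal sketch of the idea, not the package's exact update_subspace_preconditioner_ / apply_subspace_preconditioner routines:

    import torch

    def subspace_rmsprop_step(g, basis, accumulator, beta=0.99, basis_beta=0.99, eps=1e-8):
        basis.lerp_(torch.randn_like(basis), 1 - basis_beta)        # slowly changing random subspace
        g_proj = basis.T @ g                                        # (k,) projected gradient
        accumulator.mul_(beta).add_(torch.outer(g_proj, g_proj), alpha=1 - beta)
        L, Q = torch.linalg.eigh(accumulator)                       # whiten with accumulator^(-1/2)
        whitened = Q @ ((Q.T @ g_proj) / (L.clamp(min=0).sqrt() + eps))
        return basis @ whitened                                     # back to parameter space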
torchzero/modules/optimizers/soap.py CHANGED
@@ -222,8 +222,7 @@ class SOAP(Transform):
222
222
  state['Q'] = get_orthogonal_matrix(state['GG'])
223
223
 
224
224
  state['step'] = 0
225
- updates.append(tensors[i].sign().div_(10))
226
- # updates.append(tensors[i] / tensors[i].abs().sum())
225
+ updates.append(tensors[i].clip(-0.1, 0.1))
227
226
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
228
227
  # I use scaled update instead as to not mess up with next modules.
229
228
 
torchzero/modules/projections/projection.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import math
2
+ from functools import partial
2
3
  from abc import ABC, abstractmethod
3
4
  from collections.abc import Iterable
4
5
  from typing import Any, Literal
@@ -33,6 +34,25 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",
33
34
 
34
35
  return projected_closure
35
36
 
37
+ def _projected_get_grad_override(
38
+ retain_graph: bool | None = None,
39
+ create_graph: bool = False,
40
+ projection: Any = ...,
41
+ unprojected_vars: Any = ...,
42
+ self: Any = ...,
43
+ ):
44
+ assert isinstance(projection, Projection)
45
+ assert isinstance(unprojected_vars, Vars)
46
+ assert isinstance(self, Vars)
47
+
48
+ if self.grad is not None: return self.grad
49
+ grads = unprojected_vars.get_grad(retain_graph, create_graph)
50
+ projected_grads = list(projection.project(grads, self, current='grads'))
51
+ self.grad = projected_grads
52
+ for p, g in zip(self.params, projected_grads):
53
+ p.grad = g
54
+ return self.grad
55
+
36
56
 
37
57
  class Projection(Module, ABC):
38
58
  """
@@ -137,6 +157,12 @@ class Projection(Module, ABC):
137
157
 
138
158
  # step
139
159
  projected_vars.params = self._projected_params
160
+ projected_vars.get_grad = partial(
161
+ _projected_get_grad_override,
162
+ projection=self,
163
+ unprojected_vars=vars,
164
+ self=projected_vars,
165
+ )
140
166
  projected_vars = self.children['modules'].step(projected_vars)
141
167
 
142
168
  # empty fake params storage
@@ -149,7 +175,7 @@ class Projection(Module, ABC):
149
175
  unprojected_vars = projected_vars.clone(clone_update=False)
150
176
  unprojected_vars.closure = vars.closure
151
177
  unprojected_vars.params = vars.params
152
- if unprojected_vars.grad is None: unprojected_vars.grad = vars.grad
178
+ unprojected_vars.grad = vars.grad
153
179
 
154
180
  if self._project_update:
155
181
  assert projected_vars.update is not None
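The new _projected_get_grad_override is a plain function bound onto a single Vars instance with functools.partial, with self passed explicitly as a keyword, so the projected vars lazily pull and project gradients from the unprojected vars. A tiny illustration of that binding pattern (toy class, not the package's Vars):

    from functools import partial

    class Holder:
        def get_grad(self):
            return "raw-grad"

    def _get_grad_override(*, unprojected, self):
        return f"projected({unprojected.get_grad()})"   # pulled lazily from the source object

    projected, unprojected = Holder(), Holder()
    projected.get_grad = partial(_get_grad_override, unprojected=unprojected, self=projected)
    print(projected.get_grad())                         # -> projected(raw-grad)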
torchzero/modules/quasi_newton/cg.py CHANGED
@@ -64,7 +64,7 @@ class ConguateGradientBase(Transform, ABC):
64
64
  # ------------------------------- Polak-Ribière ------------------------------ #
65
65
  def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
66
66
  denom = prev_g.dot(prev_g)
67
- if denom == 0: return 0
67
+ if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
68
68
  return g.dot(g - prev_g) / denom
69
69
 
70
70
  class PolakRibiere(ConguateGradientBase):
@@ -76,8 +76,8 @@ class PolakRibiere(ConguateGradientBase):
76
76
  return polak_ribiere_beta(g, prev_g)
77
77
 
78
78
  # ------------------------------ Fletcher–Reeves ----------------------------- #
79
- def fletcher_reeves_beta(gg, prev_gg):
80
- if prev_gg == 0: return 0
79
+ def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
80
+ if prev_gg.abs() <= torch.finfo(gg.dtype).eps: return 0
81
81
  return gg / prev_gg
82
82
 
83
83
  class FletcherReeves(ConguateGradientBase):
@@ -98,7 +98,7 @@ class FletcherReeves(ConguateGradientBase):
98
98
  def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
99
99
  grad_diff = g - prev_g
100
100
  denom = prev_d.dot(grad_diff)
101
- if denom == 0: return 0
101
+ if denom.abs() < torch.finfo(g[0].dtype).eps: return 0
102
102
  return (g.dot(grad_diff) / denom).neg()
103
103
 
104
104
 
@@ -114,7 +114,7 @@ class HestenesStiefel(ConguateGradientBase):
114
114
  # --------------------------------- Dai–Yuan --------------------------------- #
115
115
  def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
116
116
  denom = prev_d.dot(g - prev_g)
117
- if denom == 0: return 0
117
+ if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
118
118
  return (g.dot(g) / denom).neg()
119
119
 
120
120
  class DaiYuan(ConguateGradientBase):
@@ -129,7 +129,7 @@ class DaiYuan(ConguateGradientBase):
129
129
  # -------------------------------- Liu-Storey -------------------------------- #
130
130
  def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
131
131
  denom = prev_g.dot(prev_d)
132
- if denom == 0: return 0
132
+ if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
133
133
  return g.dot(g - prev_g) / denom
134
134
 
135
135
  class LiuStorey(ConguateGradientBase):
@@ -159,7 +159,7 @@ class ConjugateDescent(Transform):
159
159
  self.global_state['denom'] = torch.tensor(0.).to(g[0])
160
160
 
161
161
  prev_gd = self.global_state.get('prev_gd', 0)
162
- if prev_gd == 0: beta = 0
162
+ if abs(prev_gd) <= torch.finfo(g[0].dtype).eps: beta = 0
163
163
  else: beta = g.dot(g) / prev_gd
164
164
 
165
165
  # inner step
@@ -176,7 +176,7 @@ class ConjugateDescent(Transform):
176
176
  def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
177
177
  g_diff = g - prev_g
178
178
  denom = prev_d.dot(g_diff)
179
- if denom == 0: return 0
179
+ if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
180
180
 
181
181
  term1 = 1/denom
182
182
  # term2
@@ -198,7 +198,7 @@ class HagerZhang(ConguateGradientBase):
198
198
  def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
199
199
  grad_diff = g - prev_g
200
200
  denom = prev_d.dot(grad_diff)
201
- if denom == 0: return 0
201
+ if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
202
202
 
203
203
  # Dai-Yuan
204
204
  dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
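All of the conjugate-gradient beta formulas above now guard their denominators with machine epsilon rather than an exact == 0 test, so near-zero curvature terms also fall back to beta = 0 instead of producing huge directions. The Polak-Ribière case, restated for plain 1-D tensors:

    import torch

    def polak_ribiere_beta(g, prev_g):
        denom = prev_g.dot(prev_g)
        if denom.abs() <= torch.finfo(g.dtype).eps:   # near-zero denominator -> restart with beta = 0
            return 0.0
        return (g.dot(g - prev_g) / denom).item()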
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py CHANGED
@@ -37,10 +37,11 @@ def lbfgs(
37
37
  z_tfm: Any,
38
38
  ):
39
39
  if len(s_history) == 0 or y_k is None or ys_k is None:
40
- # dir = params.grad.sign() # may work fine
41
40
 
42
- # initial step size guess taken from pytorch L-BFGS
43
- return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]
41
+ # initial step size guess modified from pytorch L-BFGS
42
+ scale = 1 / tensors_.abs().global_sum()
43
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
44
+ return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
44
45
 
45
46
  else:
46
47
  # 1st loop
torchzero/modules/quasi_newton/lbfgs.py CHANGED
@@ -36,10 +36,11 @@ def lbfgs(
36
36
  step: int,
37
37
  ):
38
38
  if len(s_history) == 0 or y_k is None or ys_k is None:
39
- # dir = params.grad.sign() # may work fine
40
39
 
41
- # initial step size guess taken from pytorch L-BFGS
42
- return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]
40
+ # initial step size guess modified from pytorch L-BFGS
41
+ scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
42
+ scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
43
+ return tensors_.mul_(scale_factor)
43
44
 
44
45
  else:
45
46
  # 1st loop
torchzero/modules/quasi_newton/lsr1.py CHANGED
@@ -17,8 +17,9 @@ def lsr1_(
17
17
  ):
18
18
  if step == 0 or not s_history:
19
19
  # initial step size guess from pytorch
20
- tensors_.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
21
- return tensors_
20
+ scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
21
+ scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
22
+ return tensors_.mul_(scale_factor)
22
23
 
23
24
  m = len(s_history)
24
25
 
@@ -64,7 +65,10 @@ def lsr1_(
64
65
  Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
65
66
 
66
67
  if scale_second and step == 1:
67
- Hx.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
68
+ scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
69
+ scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
70
+ Hx.mul_(scale_factor)
71
+
68
72
  return Hx
69
73
 
70
74
 
torchzero/modules/quasi_newton/quasi_newton.py CHANGED
@@ -68,6 +68,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
68
68
  M_key = 'H' if inverse else 'B'
69
69
  M = state.get(M_key, None)
70
70
  step = state.get('step', 0)
71
+ state['step'] = step + 1
71
72
  init_scale = settings['init_scale']
72
73
  tol = settings['tol']
73
74
  tol_reset = settings['tol_reset']
@@ -91,13 +92,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
91
92
  state['p_prev'].copy_(p)
92
93
  state['g_prev'].copy_(g)
93
94
 
94
-
95
- if reset_interval is not None and step % reset_interval == 0:
95
+ if reset_interval is not None and step != 0 and step % reset_interval == 0:
96
96
  self._reset_M_(M, s, y, inverse, init_scale)
97
97
  return
98
98
 
99
99
  # tolerance on gradient difference to avoid exploding after converging
100
- if y.abs().max() <= tol:
100
+ elif y.abs().max() <= tol:
101
101
  # reset history
102
102
  if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
103
103
  return
@@ -119,11 +119,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
119
119
 
120
120
  @torch.no_grad
121
121
  def apply_tensor(self, tensor, param, grad, state, settings):
122
- step = state['step'] = state.get('step', 0) + 1
122
+ step = state.get('step', 0)
123
123
 
124
124
  if settings['scale_second'] and step == 2:
125
- s = max(1, tensor.abs().sum()) # pyright:ignore[reportArgumentType]
126
- if s < settings['tol']: tensor = tensor/s
125
+ scale_factor = 1 / tensor.abs().sum().clip(min=1)
126
+ scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
127
+ tensor = tensor * scale_factor
127
128
 
128
129
  inverse = settings['inverse']
129
130
  if inverse:
@@ -135,7 +136,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
135
136
  return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable
136
137
 
137
138
  # to avoid typing all arguments for each method
138
- class QuasiNewtonH(HessianUpdateStrategy):
139
+ class HUpdateStrategy(HessianUpdateStrategy):
139
140
  def __init__(
140
141
  self,
141
142
  init_scale: float | Literal["auto"] = "auto",
@@ -174,7 +175,7 @@ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
174
175
  H += term1.sub_(term2)
175
176
  return H
176
177
 
177
- class BFGS(QuasiNewtonH):
178
+ class BFGS(HUpdateStrategy):
178
179
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
179
180
  return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])
180
181
 
@@ -193,7 +194,7 @@ def sr1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
193
194
  H += torch.outer(z, z).div_(denom)
194
195
  return H
195
196
 
196
- class SR1(QuasiNewtonH):
197
+ class SR1(HUpdateStrategy):
197
198
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
198
199
  return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])
199
200
 
@@ -213,7 +214,7 @@ def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
213
214
  H += term1.sub_(term2)
214
215
  return H
215
216
 
216
- class DFP(QuasiNewtonH):
217
+ class DFP(HUpdateStrategy):
217
218
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
218
219
  return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])
219
220
 
@@ -254,19 +255,19 @@ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
254
255
  H -= num/denom
255
256
  return H
256
257
 
257
- class BroydenGood(QuasiNewtonH):
258
+ class BroydenGood(HUpdateStrategy):
258
259
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
259
260
  return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])
260
261
 
261
- class BroydenBad(QuasiNewtonH):
262
+ class BroydenBad(HUpdateStrategy):
262
263
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
263
264
  return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
264
265
 
265
- class Greenstadt1(QuasiNewtonH):
266
+ class Greenstadt1(HUpdateStrategy):
266
267
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
267
268
  return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
268
269
 
269
- class Greenstadt2(QuasiNewtonH):
270
+ class Greenstadt2(HUpdateStrategy):
270
271
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
271
272
  return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
272
273
 
@@ -287,7 +288,7 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
287
288
  H[:, j] += num.squeeze() / denom
288
289
  return H
289
290
 
290
- class ColumnUpdatingMethod(QuasiNewtonH):
291
+ class ColumnUpdatingMethod(HUpdateStrategy):
291
292
  """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
292
293
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
293
294
  return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
@@ -307,7 +308,7 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor,
307
308
  H -= num/denom
308
309
  return H, R
309
310
 
310
- class ThomasOptimalMethod(QuasiNewtonH):
311
+ class ThomasOptimalMethod(HUpdateStrategy):
311
312
  """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
312
313
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
313
314
  if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
@@ -364,7 +365,7 @@ def pearson2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
364
365
  H += num.div_(sy)
365
366
  return H
366
367
 
367
- class Pearson2(QuasiNewtonH):
368
+ class Pearson2(HUpdateStrategy):
368
369
  """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
369
370
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
370
371
  return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
torchzero/modules/second_order/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from .newton import Newton
2
2
  from .newton_cg import NewtonCG
3
- from .nystrom import NystromSketchAndSolve, NystromPCG
3
+ from .nystrom import NystromSketchAndSolve, NystromPCG
torchzero/modules/second_order/newton.py CHANGED
@@ -1,14 +1,18 @@
1
1
  import warnings
2
+ from collections.abc import Callable
2
3
  from functools import partial
3
4
  from typing import Literal
4
- from collections.abc import Callable
5
+
5
6
  import torch
6
7
 
7
- from ...core import Chainable, apply, Module
8
- from ...utils import vec_to_tensors, TensorList
8
+ from ...core import Chainable, Module, apply
9
+ from ...utils import TensorList, vec_to_tensors
9
10
  from ...utils.derivatives import (
10
11
  hessian_list_to_mat,
11
12
  hessian_mat,
13
+ hvp,
14
+ hvp_fd_central,
15
+ hvp_fd_forward,
12
16
  jacobian_and_hessian_wrt,
13
17
  )
14
18
 
@@ -117,9 +121,10 @@ class Newton(Module):
117
121
  raise ValueError(hessian_method)
118
122
 
119
123
  # -------------------------------- inner step -------------------------------- #
124
+ update = vars.get_update()
120
125
  if 'inner' in self.children:
121
- g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
122
- g = torch.cat([t.view(-1) for t in g_list])
126
+ update = apply(self.children['inner'], update, params=params, grads=list(g_list), vars=vars)
127
+ g = torch.cat([t.view(-1) for t in update])
123
128
 
124
129
  # ------------------------------- regulazition ------------------------------- #
125
130
  if eig_reg: H = eig_tikhonov_(H, reg)
@@ -139,4 +144,4 @@ class Newton(Module):
139
144
  if update is None: update = least_squares_solve(H, g)
140
145
 
141
146
  vars.update = vec_to_tensors(update, params)
142
- return vars
147
+ return vars
torchzero/modules/second_order/newton_cg.py CHANGED
@@ -66,9 +66,9 @@ class NewtonCG(Module):
66
66
 
67
67
 
68
68
  # -------------------------------- inner step -------------------------------- #
69
- b = grad
69
+ b = vars.get_update()
70
70
  if 'inner' in self.children:
71
- b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
71
+ b = as_tensorlist(apply(self.children['inner'], b, params=params, grads=grad, vars=vars))
72
72
 
73
73
  # ---------------------------------- run cg ---------------------------------- #
74
74
  x0 = None
@@ -76,7 +76,7 @@ class NewtonCG(Module):
76
76
  x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
77
77
  if warm_start:
78
78
  assert x0 is not None
79
- x0.set_(x)
79
+ x0.copy_(x)
80
80
 
81
81
  vars.update = x
82
82
  return vars
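The warm-start fix replaces x0.set_(x) with x0.copy_(x): set_ rebinds the stored tensor to x's storage (so later in-place edits to x alias the warm start), while copy_ only copies values. A short demonstration:

    import torch

    x0 = torch.zeros(3)
    x = torch.tensor([1., 2., 3.])
    x0.copy_(x)          # x0 keeps its own storage
    x.add_(10)
    print(x0)            # tensor([1., 2., 3.]) -- unaffected; x0.set_(x) would have printed [11., 12., 13.]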
torchzero/modules/second_order/nystrom.py CHANGED
@@ -15,7 +15,7 @@ class NystromSketchAndSolve(Module):
15
15
  rank: int,
16
16
  reg: float = 1e-3,
17
17
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
18
- h=1e-3,
18
+ h=1e-2,
19
19
  inner: Chainable | None = None,
20
20
  seed: int | None = None,
21
21
  ):
@@ -74,9 +74,9 @@ class NystromSketchAndSolve(Module):
74
74
 
75
75
 
76
76
  # -------------------------------- inner step -------------------------------- #
77
- b = grad
77
+ b = vars.get_update()
78
78
  if 'inner' in self.children:
79
- b = apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars)
79
+ b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)
80
80
 
81
81
  # ------------------------------ sketch&n&solve ------------------------------ #
82
82
  x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
@@ -93,7 +93,7 @@ class NystromPCG(Module):
93
93
  tol=1e-3,
94
94
  reg: float = 1e-6,
95
95
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
96
- h=1e-3,
96
+ h=1e-2,
97
97
  inner: Chainable | None = None,
98
98
  seed: int | None = None,
99
99
  ):
@@ -156,9 +156,9 @@ class NystromPCG(Module):
156
156
 
157
157
 
158
158
  # -------------------------------- inner step -------------------------------- #
159
- b = grad
159
+ b = vars.get_update()
160
160
  if 'inner' in self.children:
161
- b = apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars)
161
+ b = apply(self.children['inner'], b, params=params, grads=grad, vars=vars)
162
162
 
163
163
  # ------------------------------ sketch&n&solve ------------------------------ #
164
164
  x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
torchzero/utils/linalg/benchmark.py ADDED
@@ -0,0 +1,20 @@
1
+ from collections.abc import Callable
2
+
3
+ import torch
4
+
5
+
6
+ def benchmark_solver(
7
+ A: torch.Tensor | Callable[[torch.Tensor], torch.Tensor],
8
+ b: torch.Tensor,
9
+ solver: Callable[[Callable[[torch.Tensor], torch.Tensor], torch.Tensor]]
10
+ ):
11
+ residuals = []
12
+ def A_mm(x):
13
+ if callable(A): Ax = A(x)
14
+ else: Ax = A@x
15
+ residuals.append(torch.linalg.vector_norm(Ax-b)) # pylint:disable=not-callable
16
+ return Ax
17
+
18
+ solver(A_mm, b)
19
+ return residuals
20
+
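benchmark_solver wraps the matrix-vector product and records ||Ax − b|| at every call, so any solver with an (A_mm, b) interface can be profiled. A hedged usage sketch with the package's cg (import paths assumed from the file locations in this wheel; cg's keyword defaults are shown in the solve.py hunk below):

    import torch
    from torchzero.utils.linalg.benchmark import benchmark_solver
    from torchzero.utils.linalg.solve import cg

    n = 32
    M = torch.randn(n, n)
    A = M @ M.T + n * torch.eye(n)                  # well-conditioned SPD system
    b = torch.randn(n)

    # every A @ x the solver performs is intercepted, so the residual history traces convergence
    residuals = benchmark_solver(A, b, solver=lambda A_mm, rhs: cg(A_mm, rhs, tol=1e-8, maxiter=100))
    print([round(float(r), 4) for r in residuals])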
torchzero/utils/linalg/solve.py CHANGED
@@ -8,27 +8,27 @@ from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_nume
8
8
  def cg(
9
9
  A_mm: Callable[[torch.Tensor], torch.Tensor],
10
10
  b: torch.Tensor,
11
- x0_: torch.Tensor | None,
12
- tol: float | None,
13
- maxiter: int | None,
11
+ x0_: torch.Tensor | None = None,
12
+ tol: float | None = 1e-4,
13
+ maxiter: int | None = None,
14
14
  reg: float = 0,
15
15
  ) -> torch.Tensor: ...
16
16
  @overload
17
17
  def cg(
18
18
  A_mm: Callable[[TensorList], TensorList],
19
19
  b: TensorList,
20
- x0_: TensorList | None,
21
- tol: float | None,
22
- maxiter: int | None,
20
+ x0_: TensorList | None = None,
21
+ tol: float | None = 1e-4,
22
+ maxiter: int | None = None,
23
23
  reg: float | list[float] | tuple[float] = 0,
24
24
  ) -> TensorList: ...
25
25
 
26
26
  def cg(
27
27
  A_mm: Callable,
28
28
  b: torch.Tensor | TensorList,
29
- x0_: torch.Tensor | TensorList | None,
30
- tol: float | None,
31
- maxiter: int | None,
29
+ x0_: torch.Tensor | TensorList | None = None,
30
+ tol: float | None = 1e-4,
31
+ maxiter: int | None = None,
32
32
  reg: float | list[float] | tuple[float] = 0,
33
33
  ):
34
34
  def A_mm_reg(x): # A_mm with regularization
@@ -90,7 +90,7 @@ def nystrom_sketch_and_solve(
90
90
  A_mm: Callable[[torch.Tensor], torch.Tensor],
91
91
  b: torch.Tensor,
92
92
  rank: int,
93
- reg: float,
93
+ reg: float = 1e-3,
94
94
  generator=None,
95
95
  ) -> torch.Tensor:
96
96
  U, lambd = nystrom_approximation(
@@ -116,10 +116,10 @@ def nystrom_pcg(
116
116
  A_mm: Callable[[torch.Tensor], torch.Tensor],
117
117
  b: torch.Tensor,
118
118
  sketch_size: int,
119
- reg: float,
120
- x0_: torch.Tensor | None,
121
- tol: float | None,
122
- maxiter: int | None,
119
+ reg: float = 1e-6,
120
+ x0_: torch.Tensor | None = None,
121
+ tol: float | None = 1e-4,
122
+ maxiter: int | None = None,
123
123
  generator=None,
124
124
  ) -> torch.Tensor:
125
125
  U, lambd = nystrom_approximation(
@@ -166,3 +166,4 @@ def nystrom_pcg(
166
166
  z = P_inv @ residual
167
167
  beta = residual.dot(z) / rz
168
168
  p = z + p*beta
169
+
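With the new defaults, the low-level solvers can be called with just an operator and a right-hand side. A short sketch under the same assumed import path as above:

    import torch
    from torchzero.utils.linalg.solve import cg, nystrom_pcg

    A = torch.diag(torch.arange(1., 101.))
    b = torch.ones(100)
    x = cg(lambda v: A @ v, b)                               # x0_/tol/maxiter now default to None/1e-4/None
    x_pcg = nystrom_pcg(lambda v: A @ v, b, sketch_size=20)  # reg/x0_/tol/maxiter also defaulted
    print(torch.linalg.vector_norm(A @ x - b))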
{torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchzero
3
- Version: 0.3.6
3
+ Version: 0.3.9
4
4
  Summary: Modular optimization library for PyTorch.
5
5
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
6
  License: MIT License
@@ -156,7 +156,7 @@ for epoch in range(100):
156
156
  * `Newton`: Classic Newton's method.
157
157
  * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
158
158
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
159
- * `NystromPCG`: NewtonCG with Nyström preconditioning (my current recommendation).
159
+ * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
160
160
 
161
161
  * **Quasi-Newton**: Approximate second-order optimization methods.
162
162
  * `LBFGS`: Limited-memory BFGS.
{torchzero-0.3.6.dist-info → torchzero-0.3.9.dist-info}/RECORD CHANGED
@@ -1,14 +1,14 @@
1
1
  docs/source/conf.py,sha256=jd80ZT2IdCx7nlQrpOTJL8UhGBNm6KYyXlpp0jmRiAw,1849
2
2
  tests/test_identical.py,sha256=NZ7A8Rm1U9Q16d-cG2G_wccpPtNALyoKYJt9qMownMc,11568
3
3
  tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
4
- tests/test_opts.py,sha256=oDZVFr9AE9ZhyR-sImSgNzQsbPsUtJLzuLd1Nxgkp1w,40850
5
- tests/test_tensorlist.py,sha256=VWX9wYdfkG-0Y8I0wWPp56ZJM0mBNPvS_SC3irmcYcs,72427
4
+ tests/test_opts.py,sha256=TZVaCv2ZLdHSkL6snTEkqhTMHqlcO55L-c56k6Hh4xc,40850
5
+ tests/test_tensorlist.py,sha256=Djpr5C0T5d_gz-j-P-bpo_X51DC4twbtT9c-xDSFbP0,72438
6
6
  tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
7
7
  tests/test_vars.py,sha256=3p9dsHk7SJpMd-WRD0ziBNq5FEHRBJGSxbMLD8ES4J0,6815
8
8
  torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
9
9
  torchzero/core/__init__.py,sha256=2JRyeGZprTexAeEPQOIl9fLFGBwzvya-AwKyt7XAmGQ,210
10
10
  torchzero/core/module.py,sha256=Razw3c71Kfegznm0vQxsii1KuTUCPBC9UGyq2v-KX4M,27568
11
- torchzero/core/preconditioner.py,sha256=rMYusKbaypm5K0Ii9VdjKhxi2YWNQbBk9f6AV_MJulY,6191
11
+ torchzero/core/preconditioner.py,sha256=n9oh7kZdt1kU3Wh472lnvLrsXwhR5Wqe6lIp7JuAJ_I,6336
12
12
  torchzero/core/transform.py,sha256=ajNJcX45ds-_lc5CqxgLfEFGil6_BYLerB0WvoTi8rM,10303
13
13
  torchzero/modules/__init__.py,sha256=BDeyuSd2s1WFUUXIo3tGTNp4aYp4A2B94cydpPW24nY,332
14
14
  torchzero/modules/functional.py,sha256=HXNzmPe7LsPadryEm7zrcEKqGej16QDwSgBkbEvggFM,6492
@@ -16,18 +16,20 @@ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLB
16
16
  torchzero/modules/clipping/clipping.py,sha256=I-5utyrqdKtF5yaH-9m2F3UqdfpPmA2bSSFUAZ_d60Q,12544
17
17
  torchzero/modules/clipping/ema_clipping.py,sha256=pLeNuEBLpJ74io2sHn_ZVYaQ6ydEfhpVfVEX2bFttd0,5947
18
18
  torchzero/modules/clipping/growth_clipping.py,sha256=OD-kdia2Rn-DvYlYV6EZlGPDVTh9tj-W9mpiZPc3hOQ,6772
19
- torchzero/modules/experimental/__init__.py,sha256=sJ6URgX35P3zJ2ugBKgAcwBWmdBmAPDW3vXHQ0sK-ro,443
20
- torchzero/modules/experimental/absoap.py,sha256=XUHr5SeLdhLW2kMvWea5xAqZeuJBDQoO4zprDxs4bgU,13317
21
- torchzero/modules/experimental/adadam.py,sha256=W7rRXYJ9tGrzqD_FdFX00HBLuWOEr2tHtfshf6lDFYE,4049
22
- torchzero/modules/experimental/adamY.py,sha256=FoSn-qMI5_BdqZH10WGKkl-zYTPESBdGZ9lfhyqnbB0,4591
23
- torchzero/modules/experimental/adasoap.py,sha256=07gPdEdBIKtmdmSzTGtTO0c2ZkS_otVLufQ76okBjHY,11239
24
- torchzero/modules/experimental/algebraic_newton.py,sha256=_XFYR6bdHWgA5bozxc9AJYteBIAnHrSLgo_bSaZ13eg,5193
25
- torchzero/modules/experimental/curveball.py,sha256=Nw9jtSp5QNj7-FN3qshjYEDHc68LwRLha-Co78mfR5w,3242
26
- torchzero/modules/experimental/dsoap.py,sha256=BEZDw3_n5VDhu7VLgkoSN4rI9JeBdGoO9gFZfqsh74M,10983
27
- torchzero/modules/experimental/gradmin.py,sha256=55dpBDNyrYJusluFhw-v1BXuj1UxER7pNEPTtwYKD4E,3648
19
+ torchzero/modules/experimental/__init__.py,sha256=fEPDYDl7qhaFoferDRmG3ehwuqSvx4Vt2uOz0Y7h4to,483
20
+ torchzero/modules/experimental/absoap.py,sha256=Z4MS4pDPSQ9IaTk8g57OfrsWcYVOT72x533KKtn2Zxk,13512
21
+ torchzero/modules/experimental/adadam.py,sha256=OAPF1-NUbg79V3QOTYzsQlRC97C7XHj5boOLDqLz3PE,4029
22
+ torchzero/modules/experimental/adamY.py,sha256=g1pAHwgdyDdKvObZ67lCSc36L99tl5jlQgOr4lMJCDo,4595
23
+ torchzero/modules/experimental/adasoap.py,sha256=JdV6rB9xfqL3vbHpZCLmkJZKRObZ1nVoEmabtIeVT3E,11195
24
+ torchzero/modules/experimental/algebraic_newton.py,sha256=sq5ZD_j_EtlxIjNnS0rKKwTSG_JuwsZOg9ZMMQTuQm0,5154
25
+ torchzero/modules/experimental/curveball.py,sha256=Uk30uLEztTHD5IUJLJm9Nn3x31DF9kQHmeLFhc065us,3262
26
+ torchzero/modules/experimental/gradmin.py,sha256=iJmEvDEdVdck0C-94pY3iGxnIoNv6Fu6vj3f7lS6aQM,3686
27
+ torchzero/modules/experimental/newton_solver.py,sha256=iGI2LHLaZd2ovpbq1Vogs76os0zWG7VwM7nUz8RzxVg,3071
28
28
  torchzero/modules/experimental/reduce_outward_lr.py,sha256=kjtRwepBGBca77ToM-lw3b8ywptMtmSdC_jQfjJAwlY,1184
29
- torchzero/modules/experimental/spectral.py,sha256=D3_nCI8teFirCdnnLprNnZ3G1gsOB6RUBWCeDbwi7P0,12043
30
- torchzero/modules/experimental/subspace_preconditioners.py,sha256=4SRJOyTG-fJCGunHR62aRrzw3qFmeI6fRQAYHIadhWw,4682
29
+ torchzero/modules/experimental/soapy.py,sha256=Ishd2Jj6BbhjrLyC48zf-cjMmA1kJb_uKXESQBIML_s,10990
30
+ torchzero/modules/experimental/spectral.py,sha256=8_n208V2yPY3z5pCym-FvwO7DGFhozNgWlpIBtQSdrI,12139
31
+ torchzero/modules/experimental/structured_newton.py,sha256=uWczR-uAXHaFwf0mlOThv2sLG0irH6Gz1hKlGHtPAj4,3386
32
+ torchzero/modules/experimental/subspace_preconditioners.py,sha256=WnHpga7Kx4-N2xU5vP3uUHRER70ymyNJCWbSx2zXWOk,4976
31
33
  torchzero/modules/experimental/tropical_newton.py,sha256=uq66ouhgrgc8iYGozDQ3_rtbubj8rKRwb1jfcdnlpHg,4903
32
34
  torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
33
35
  torchzero/modules/grad_approximation/fdm.py,sha256=2PNNBIMup1xlOwLFAwAS3xAVd-7GGVyerMeKH1ug9LQ,3591
@@ -70,26 +72,26 @@ torchzero/modules/optimizers/orthograd.py,sha256=5BLnNJTYuGUClHmlxaXZ1jNvBR4zSFD
70
72
  torchzero/modules/optimizers/rmsprop.py,sha256=d10Y9Ck-391tVysO3xMHg3g2Pe0UEZplgebEyDYi3Z4,4333
71
73
  torchzero/modules/optimizers/rprop.py,sha256=n4k5-9F3ppH0Xl-4l4vNXfqVf2r67vMPCkstUaQKPLw,10974
72
74
  torchzero/modules/optimizers/shampoo.py,sha256=AHHV6d71DqKDPCg52ShWIPIRSGtWkMc1v1XwXgDG3qY,8606
73
- torchzero/modules/optimizers/soap.py,sha256=HL1YrfiEiRMh6aW9D5UEZXBjo3yMTqnpKPHXVD8fOa8,11590
75
+ torchzero/modules/optimizers/soap.py,sha256=Kf2BAtIf2QY1V2ZJcUjRLcp2WfIVLd3mNclnaT3Nmds,11520
74
76
  torchzero/modules/optimizers/sophia_h.py,sha256=8pSlYVm66xWplzdP8MX3MCTzzIYHsxGzDEXJKA03Zgg,4279
75
77
  torchzero/modules/projections/__init__.py,sha256=OCxlh_-Tx-xpl31X03CeFJH9XveH563oEsWc8rUvX0A,196
76
78
  torchzero/modules/projections/dct.py,sha256=wxaEV6dTNiOqW_n2UHX0De6mMXTKDXK6UNcMNI4Rogk,2373
77
79
  torchzero/modules/projections/fft.py,sha256=OpCcEM1-A2dgk1umwRsBsvK7ObiHtsBKlkkcw0IX83Q,2961
78
80
  torchzero/modules/projections/galore.py,sha256=c9CZ0kHxpKEoyfc_lnmeHOkNp55jCppb7onN5YmWnN8,242
79
- torchzero/modules/projections/projection.py,sha256=tvUBZ4XGY1GkOg6jrKS7FvpIpjUc2FJL_SMRpoROT1E,9330
81
+ torchzero/modules/projections/projection.py,sha256=aYufSD3ftRUqVScPmqxwEFgP1P8ioxM8z9eyzaL7d10,10147
80
82
  torchzero/modules/projections/structural.py,sha256=QaCGHmzHCXj46sM-XZ5XlYU9BnuRKI2ReR3LE8y2R4g,5740
81
83
  torchzero/modules/quasi_newton/__init__.py,sha256=0iOlX73PHj9lQS3_2cJ5lyCdas904MnFfIvR8Popvzw,402
82
- torchzero/modules/quasi_newton/cg.py,sha256=h-di1oKKP1tDoh-LogBRIRCp2UF9GA6XjEJPlX6xXf4,9322
83
- torchzero/modules/quasi_newton/lbfgs.py,sha256=jtO5ldbx66yUWv-20c-4mvq6HhCMuomCwJK8A8bjcYA,9168
84
- torchzero/modules/quasi_newton/lsr1.py,sha256=F_DtMQZfQSjmSLjnx4nw16AV7qCdNxT9ITQbfNFrPdM,5879
84
+ torchzero/modules/quasi_newton/cg.py,sha256=lIJvfWAZ08r0o4uqaJnRG6pvcE2kBkJUkZ1MK37KMTk,9602
85
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=SMgesPMZ4ubVeG7R395SnAb5ffkyPHbzSQMqPlLGI7U,9211
86
+ torchzero/modules/quasi_newton/lsr1.py,sha256=XmYyYANzQgQuFtOMW59znQrS-mprGRXazicfB9JAup8,6059
85
87
  torchzero/modules/quasi_newton/olbfgs.py,sha256=2YAOXlMnPGw22sNcIMH1hmggzAXQRbN59RSPUZNKUZY,8352
86
- torchzero/modules/quasi_newton/quasi_newton.py,sha256=jwQkzlnozIaxHW9kuDAAlME0YuQdrdZX9OZZoTmej4Q,17384
88
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=rUp4s3MbACcOjwpz00TAjl-olif50voTmC16vv5XrSE,17496
87
89
  torchzero/modules/quasi_newton/experimental/__init__.py,sha256=3qpZGgdsx6wpoafWaNWx-eamRl1FuxVCWQZq8Y7Cl98,39
88
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=PlyuIH2pFazIR89OGTrZESt752GkbArh_Zb8mtVCOi0,10731
89
- torchzero/modules/second_order/__init__.py,sha256=5lRmwIU53eRc1owpOZ5FMDc7u1Z48I3PDc0NyCBaJNM,113
90
- torchzero/modules/second_order/newton.py,sha256=XNhscAuWwxOUwps3sUrxc2ExgkNFbilnAdszrCvQxFg,5845
91
- torchzero/modules/second_order/newton_cg.py,sha256=lUVn4-ZoW3qAxqEy8i7yz_aN7sZDoQChd-A_Ubrz-Ag,2871
92
- torchzero/modules/second_order/nystrom.py,sha256=ZyCWrde-_-Ednj46jafuvBOzG3nC-3cPYGr-HytZbsE,6073
90
+ torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=ec6JKYX89xA_UlY9VrMB3hBjDyNKwkalS_4JQGA1qOY,10762
91
+ torchzero/modules/second_order/__init__.py,sha256=jolCGaIVkID9hpxgx0Tc22wgjVlwuWekWjKTMe5jKXw,114
92
+ torchzero/modules/second_order/newton.py,sha256=xxkrhFK4i5I9oOX3AGGh_6bXNDUSFq4D0pw3c7qgEd8,5925
93
+ torchzero/modules/second_order/newton_cg.py,sha256=PILHRf2koop_cywE1RNGukT16alDO7prC4C3HlZcW30,2861
94
+ torchzero/modules/second_order/nystrom.py,sha256=zdLSTQ_S5VViUt2sAmFNoDCCHKmHP2A7112czkZNlUk,6051
93
95
  torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
94
96
  torchzero/modules/smoothing/gaussian.py,sha256=YlT_G4MqAVkiWG56RHAwgt5SSPISpvQZQbSLh8mhF3I,6153
95
97
  torchzero/modules/smoothing/laplacian.py,sha256=Bfrs7D59SfdU7j-97UBKD1hs0obC-ZgjJvG7oKwaa0o,5065
@@ -116,13 +118,14 @@ torchzero/utils/python_tools.py,sha256=RFBqNj8w52dpJ983pUPPDbg2x1MX_-SsBnBMffWGG
116
118
  torchzero/utils/tensorlist.py,sha256=qSbiliVo1euFAksdHHHRbPUdYYxfkw1dvhpXj71wGy0,53162
117
119
  torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
118
120
  torchzero/utils/linalg/__init__.py,sha256=Dzbho3_z7JDdKzYD-QdLArg0ZEoC2BVGdlE3JoAnXHQ,272
121
+ torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
119
122
  torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
120
123
  torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
121
124
  torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
122
- torchzero/utils/linalg/solve.py,sha256=hN450ONzAirYOvWF2g0E0Wy2n1bCw4X-KXWi6p4jvDM,5136
125
+ torchzero/utils/linalg/solve.py,sha256=P0PMi0zro3G3Rd0X-JeoLk7tqYDB0js0aB4bpQ0OABU,5235
123
126
  torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
124
- torchzero-0.3.6.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
125
- torchzero-0.3.6.dist-info/METADATA,sha256=wjXJuO_WRQYv15BSA_9yo2qe2xe7jET7YOy8xb9YmnE,13944
126
- torchzero-0.3.6.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
127
- torchzero-0.3.6.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
128
- torchzero-0.3.6.dist-info/RECORD,,
127
+ torchzero-0.3.9.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
128
+ torchzero-0.3.9.dist-info/METADATA,sha256=aENIaMgy94tD6nakRWfApleVSy6bxW8-q3-mQeVSeGA,13941
129
+ torchzero-0.3.9.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
130
+ torchzero-0.3.9.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
131
+ torchzero-0.3.9.dist-info/RECORD,,