torchzero 0.3.5__py3-none-any.whl → 0.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. tests/test_opts.py +1 -1
  2. tests/test_tensorlist.py +17 -17
  3. torchzero/core/preconditioner.py +11 -10
  4. torchzero/modules/experimental/__init__.py +3 -2
  5. torchzero/modules/experimental/absoap.py +8 -2
  6. torchzero/modules/experimental/adadam.py +1 -1
  7. torchzero/modules/experimental/adamY.py +1 -1
  8. torchzero/modules/experimental/adasoap.py +1 -1
  9. torchzero/modules/experimental/algebraic_newton.py +1 -1
  10. torchzero/modules/experimental/curveball.py +1 -1
  11. torchzero/modules/experimental/gradmin.py +1 -1
  12. torchzero/modules/experimental/newton_solver.py +88 -0
  13. torchzero/modules/experimental/{dsoap.py → soapy.py} +4 -4
  14. torchzero/modules/experimental/spectral.py +5 -3
  15. torchzero/modules/experimental/subspace_preconditioners.py +16 -9
  16. torchzero/modules/optimizers/soap.py +1 -2
  17. torchzero/modules/projections/projection.py +27 -1
  18. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +4 -3
  19. torchzero/modules/quasi_newton/lbfgs.py +4 -3
  20. torchzero/modules/quasi_newton/lsr1.py +6 -3
  21. torchzero/modules/quasi_newton/quasi_newton.py +16 -17
  22. torchzero/modules/second_order/__init__.py +1 -1
  23. torchzero/modules/second_order/newton_cg.py +1 -1
  24. torchzero/utils/linalg/benchmark.py +20 -0
  25. torchzero/utils/linalg/solve.py +15 -14
  26. {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/METADATA +2 -2
  27. {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/RECORD +30 -28
  28. {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/WHEEL +0 -0
  29. {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/licenses/LICENSE +0 -0
  30. {torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/top_level.txt +0 -0
tests/test_opts.py CHANGED
@@ -745,7 +745,7 @@ SSVM = Run(
745
745
  func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
746
746
  sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
747
747
  needs_closure=True,
748
- func='rosen', steps=50, loss=1e-12, merge_invariant=True,
748
+ func='rosen', steps=50, loss=0.02, merge_invariant=True,
749
749
  sphere_steps=10, sphere_loss=0,
750
750
  )
751
751
 
tests/test_tensorlist.py CHANGED
@@ -1301,7 +1301,7 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
1301
1301
  expected_tl = TensorList(expected_list)
1302
1302
  assert isinstance(result, TensorList)
1303
1303
  assert len(result) == len(expected_tl)
1304
- assert_tl_allclose(result, expected_tl, atol=1e-6) # Use allclose due to potential float variations
1304
+ assert_tl_allclose(result, expected_tl, atol=1e-3) # Use allclose due to potential float variations
1305
1305
 
1306
1306
  # --- Grafting, Rescaling, Normalizing, Clipping ---
1307
1307
 
@@ -1381,8 +1381,8 @@ def test_rescale(simple_tl: TensorList, dim):
1381
1381
  assert torch.allclose(rescaled_scalar.global_min(), torch.tensor(min_val))
1382
1382
  assert torch.allclose(rescaled_scalar.global_max(), torch.tensor(max_val))
1383
1383
  else:
1384
- assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-4)
1385
- assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-4)
1384
+ assert_tl_allclose(rescaled_scalar_min, TensorList([torch.full_like(t, min_val) for t in rescaled_scalar_min]),atol=1e-3)
1385
+ assert_tl_allclose(rescaled_scalar_max, TensorList([torch.full_like(t, max_val) for t in rescaled_scalar_max]),atol=1e-3)
1386
1386
 
1387
1387
 
1388
1388
  # Rescale list
@@ -1402,8 +1402,8 @@ def test_rescale(simple_tl: TensorList, dim):
1402
1402
  assert global_max_rescaled < avg_max + 1.0 # Loose check
1403
1403
 
1404
1404
  else:
1405
- assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-4)
1406
- assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-4)
1405
+ assert_tl_allclose(rescaled_list_min, TensorList([torch.full_like(t, mn) for t, mn in zip(rescaled_list_min, min_list)]),atol=1e-3)
1406
+ assert_tl_allclose(rescaled_list_max, TensorList([torch.full_like(t, mx) for t, mx in zip(rescaled_list_max, max_list)]),atol=1e-3)
1407
1407
 
1408
1408
  # Rescale to 01 helper
1409
1409
  rescaled_01 = simple_tl.rescale_to_01(dim=dim, eps=eps)
@@ -1413,8 +1413,8 @@ def test_rescale(simple_tl: TensorList, dim):
1413
1413
  assert torch.allclose(rescaled_01.global_min(), torch.tensor(0.0))
1414
1414
  assert torch.allclose(rescaled_01.global_max(), torch.tensor(1.0))
1415
1415
  else:
1416
- assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-4)
1417
- assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-4)
1416
+ assert_tl_allclose(rescaled_01_min, TensorList([torch.zeros_like(t) for t in rescaled_01_min]), atol=1e-3)
1417
+ assert_tl_allclose(rescaled_01_max, TensorList([torch.ones_like(t) for t in rescaled_01_max]), atol=1e-3)
1418
1418
 
1419
1419
 
1420
1420
  # Test inplace
@@ -1454,11 +1454,11 @@ def test_normalize(big_tl: TensorList, dim):
1454
1454
  normalized_scalar_var = normalized_scalar.var(dim=dim if dim != 'global' else None)
1455
1455
 
1456
1456
  if dim == 'global':
1457
- assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-4)
1458
- assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-4)
1457
+ assert torch.allclose(normalized_scalar.global_mean(), torch.tensor(mean_val), atol=1e-3)
1458
+ assert torch.allclose(normalized_scalar.global_var(), torch.tensor(var_val), atol=1e-3)
1459
1459
  else:
1460
- assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-4)
1461
- assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-4)
1460
+ assert_tl_allclose(normalized_scalar_mean, TensorList([torch.full_like(t, mean_val) for t in normalized_scalar_mean]), atol=1e-3)
1461
+ assert_tl_allclose(normalized_scalar_var, TensorList([torch.full_like(t, var_val) for t in normalized_scalar_var]), atol=1e-3)
1462
1462
 
1463
1463
  # Normalize list mean/var
1464
1464
  normalized_list = simple_tl.normalize(mean_list, var_list, dim=dim)
@@ -1476,19 +1476,19 @@ def test_normalize(big_tl: TensorList, dim):
1476
1476
  # assert torch.allclose(global_mean_rescaled, torch.tensor(avg_mean), rtol=1e-1, atol=1e-1) # Loose check
1477
1477
  # assert torch.allclose(global_var_rescaled, torch.tensor(avg_var), rtol=1e-1, atol=1e-1) # Loose check
1478
1478
  else:
1479
- assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-4)
1480
- assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-4)
1479
+ assert_tl_allclose(normalized_list_mean, TensorList([torch.full_like(t, m) for t, m in zip(normalized_list_mean, mean_list)]), atol=1e-3)
1480
+ assert_tl_allclose(normalized_list_var, TensorList([torch.full_like(t, v) for t, v in zip(normalized_list_var, var_list)]), atol=1e-3)
1481
1481
 
1482
1482
  # Z-normalize helper
1483
1483
  znorm = simple_tl.znormalize(dim=dim, eps=1e-10)
1484
1484
  znorm_mean = znorm.mean(dim=dim if dim != 'global' else None)
1485
1485
  znorm_var = znorm.var(dim=dim if dim != 'global' else None)
1486
1486
  if dim == 'global':
1487
- assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-4)
1488
- assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-4)
1487
+ assert torch.allclose(znorm.global_mean(), torch.tensor(0.0), atol=1e-3)
1488
+ assert torch.allclose(znorm.global_var(), torch.tensor(1.0), atol=1e-3)
1489
1489
  else:
1490
- assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-4)
1491
- assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-4)
1490
+ assert_tl_allclose(znorm_mean, TensorList([torch.zeros_like(t) for t in znorm_mean]), atol=1e-3)
1491
+ assert_tl_allclose(znorm_var, TensorList([torch.ones_like(t) for t in znorm_var]), atol=1e-3)
1492
1492
 
1493
1493
 
1494
1494
  # Test inplace
torchzero/core/preconditioner.py CHANGED
@@ -38,7 +38,7 @@ class Preconditioner(Transform):
38
38
 
39
39
 
40
40
  def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
41
- step = self.global_state.get('step', 0)
41
+ step = self.global_state.get('__step', 0)
42
42
  states = [self.state[p] for p in params]
43
43
  settings = [self.settings[p] for p in params]
44
44
  global_settings = settings[0]
@@ -47,8 +47,10 @@ class Preconditioner(Transform):
47
47
  scale_first = global_settings['__scale_first']
48
48
  scale_factor = 0
49
49
  if scale_first and step == 0:
50
- # initial step size guess from pytorch LBFGS
51
- scale_factor = TensorList(tensors).abs().sum()
50
+ # initial step size guess from pytorch LBFGS was too unstable
51
+ # I switched to norm
52
+ tensors = TensorList(tensors)
53
+ scale_factor = tensors.abs().global_mean().clip(min=1)
52
54
 
53
55
  # update preconditioner
54
56
  if step % update_freq == 0:
@@ -65,11 +67,11 @@ class Preconditioner(Transform):
65
67
  if scale_first and step == 0:
66
68
  torch._foreach_div_(tensors, scale_factor)
67
69
 
68
- self.global_state['step'] = step + 1
70
+ self.global_state['__step'] = step + 1
69
71
  return tensors
70
72
 
71
73
  def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
72
- step = self.global_state.get('step', 0)
74
+ step = self.global_state.get('__step', 0)
73
75
  tensors_vec = torch.cat([t.ravel() for t in tensors])
74
76
  params_vec = torch.cat([p.ravel() for p in params])
75
77
  grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
@@ -82,8 +84,8 @@ class Preconditioner(Transform):
82
84
  scale_first = global_settings['__scale_first']
83
85
  scale_factor = 0
84
86
  if scale_first and step == 0:
85
- # initial step size guess from pytorch LBFGS
86
- scale_factor = tensors_vec.abs().sum()
87
+ # initial step size guess from pytorch LBFGS was too unstable
88
+ scale_factor = tensors_vec.abs().mean().clip(min=1)
87
89
 
88
90
  # update preconditioner
89
91
  if step % update_freq == 0:
@@ -99,11 +101,10 @@ class Preconditioner(Transform):
99
101
 
100
102
  # scale initial step, when preconditioner might not have been applied
101
103
  if scale_first and step == 0:
102
- if scale_factor >= torch.finfo(tensors_vec.dtype).eps:
103
- tensors_vec /= scale_factor
104
+ tensors_vec /= scale_factor
104
105
 
105
106
  tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
106
- self.global_state['step'] = step + 1
107
+ self.global_state['__step'] = step + 1
107
108
  return tensors
108
109
 
109
110
  @torch.no_grad
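The two Preconditioner hunks above replace the PyTorch-L-BFGS-style sum(|g|) first-step guess (described as "too unstable") with mean(|g|) clipped to at least 1, so the very first unpreconditioned update can be shrunk but never enlarged. A minimal standalone sketch of that scaling rule in plain torch; the helper name is hypothetical:

```python
import torch

def scale_first_step(update_vec: torch.Tensor) -> torch.Tensor:
    # divide by mean(|g|), but never scale the update up: clip(min=1)
    scale_factor = update_vec.abs().mean().clip(min=1)
    return update_vec / scale_factor

g = torch.tensor([10.0, -20.0, 30.0])
print(scale_first_step(g))          # divided by mean(|g|) = 20
print(scale_first_step(g * 1e-3))   # mean(|g|) < 1, so left unchanged
```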
torchzero/modules/experimental/__init__.py CHANGED
@@ -3,7 +3,7 @@ from .adadam import Adadam
3
3
  from .adamY import AdamY
4
4
  from .adasoap import AdaSOAP
5
5
  from .curveball import CurveBall
6
- from .dsoap import DSOAP
6
+ from .soapy import SOAPY
7
7
  from .gradmin import GradMin
8
8
  from .reduce_outward_lr import ReduceOutwardLR
9
9
  from .spectral import SpectralPreconditioner
@@ -11,4 +11,5 @@ from .subspace_preconditioners import (
11
11
  HistorySubspacePreconditioning,
12
12
  RandomSubspacePreconditioning,
13
13
  )
14
- from .tropical_newton import TropicalNewton
14
+ from .tropical_newton import TropicalNewton
15
+ from .newton_solver import NewtonSolver
torchzero/modules/experimental/absoap.py CHANGED
@@ -140,11 +140,17 @@ Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
140
140
  class ABSOAP(Transform):
141
141
  """SOAP but with two extra letters included in its name in order to improve converence
142
142
 
143
+ so what you can do is choose which quantity goes into which accumulator, and that is supposed to be good.
144
+
143
145
  new args
144
146
 
145
147
  scale by s whether to scale gradient differences by parameter differences
146
148
 
147
149
  y_to_ema2 whether to use gradient differences for exponential moving average too
150
+
151
+ okay, I have since changed these args into different ones
152
+
153
+ BASICALLY THIS IS FOR MY EXPERIMENTS
148
154
  """
149
155
  def __init__(
150
156
  self,
@@ -213,7 +219,7 @@ class ABSOAP(Transform):
213
219
  if 'g_prev' not in state:
214
220
  state['p_prev'] = p.clone()
215
221
  state['g_prev'] = t.clone()
216
- updates.append(tensors[i].sign())
222
+ updates.append(tensors[i].clip(-0.1,0.1))
217
223
  continue
218
224
 
219
225
  p_prev = state['p_prev']
@@ -285,7 +291,7 @@ class ABSOAP(Transform):
285
291
  state['Q'] = get_orthogonal_matrix(state['GG'])
286
292
 
287
293
  state['step'] = 0
288
- updates.append(tensors[i].sign())
294
+ updates.append(tensors[i].clip(-0.1,0.1))
289
295
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
290
296
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
291
297
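ABSOAP above, and the AdaSOAP, SOAPY and SOAP hunks further down, all swap the first-step fallback from sign(g), which is what the first Adam step effectively produces, to clip(g, -0.1, 0.1). A tiny plain-torch comparison of the two fallbacks:

```python
import torch

g = torch.tensor([2.0, -0.03, 1e-6])
print(g.sign())            # every entry becomes ±1, magnitudes are discarded
print(g.clip(-0.1, 0.1))   # large entries are bounded, small ones kept as-is
```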
 
torchzero/modules/experimental/adadam.py CHANGED
@@ -50,7 +50,7 @@ def adadam_(
50
50
  return None
51
51
 
52
52
  class Adadam(Module):
53
- """Adam with a diagonally preconditioned preconditioner and a graceful name."""
53
+ """Adam with a diagonally preconditioned preconditioner."""
54
54
  def __init__(
55
55
  self,
56
56
  beta1: float = 0.9,
torchzero/modules/experimental/adamY.py CHANGED
@@ -37,7 +37,7 @@ def adamy_(
37
37
  p_prev.copy_(p)
38
38
  g_prev.copy_(g)
39
39
 
40
- update = g.sign().lazy_mul_(alpha*0.1)
40
+ update = g.clip(-0.1,0.1).lazy_mul_(alpha)
41
41
  if params_ is None: return update
42
42
  params_.sub_(update)
43
43
  return None
torchzero/modules/experimental/adasoap.py CHANGED
@@ -218,7 +218,7 @@ class AdaSOAP(Transform):
218
218
  state['Q'] = get_orthogonal_matrix(GG_precond)
219
219
 
220
220
  state['step'] = 0
221
- updates.append(tensors[i].sign())
221
+ updates.append(tensors[i].clip(-0.1,0.1))
222
222
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
223
223
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
224
224
 
torchzero/modules/experimental/algebraic_newton.py CHANGED
@@ -71,7 +71,7 @@ def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemir
71
71
 
72
72
 
73
73
  class AlgebraicNewton(Module):
74
- """newton in other algebras, not practical because solving linear system is very hard."""
74
+ """newton in other algebras, not that it works."""
75
75
  def __init__(
76
76
  self,
77
77
  reg: float | None = None,
torchzero/modules/experimental/curveball.py CHANGED
@@ -13,7 +13,7 @@ def curveball(
13
13
  momentum: float | NumberList,
14
14
  precond_lr: float | NumberList,
15
15
  ):
16
- """returns z_, clone it!!!"""
16
+ """returns z_, clone it!!! (no just negate it)"""
17
17
  delta = Hz + tensors
18
18
  z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
19
19
  return z_
torchzero/modules/experimental/gradmin.py CHANGED
@@ -14,7 +14,7 @@ from ..smoothing.gaussian import Reformulation
14
14
 
15
15
 
16
16
  class GradMin(Reformulation):
17
- """Reformulates the objective to minimize sum of gradient magnitudes via autograd.
17
+ """Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.
18
18
 
19
19
  Args:
20
20
  loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
torchzero/modules/experimental/newton_solver.py ADDED
@@ -0,0 +1,88 @@
1
+ from collections.abc import Callable, Iterable
2
+ from typing import Any, Literal, overload
3
+
4
+ import torch
5
+
6
+ from ...core import Chainable, Module, apply, Modular
7
+ from ...utils import TensorList, as_tensorlist
8
+ from ...utils.derivatives import hvp
9
+ from ..quasi_newton import LBFGS
10
+
11
+ class NewtonSolver(Module):
12
+ """Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
13
+ def __init__(
14
+ self,
15
+ solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
16
+ maxiter=None,
17
+ tol=1e-3,
18
+ reg: float = 0,
19
+ warm_start=True,
20
+ inner: Chainable | None = None,
21
+ ):
22
+ defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
23
+ super().__init__(defaults,)
24
+
25
+ if inner is not None:
26
+ self.set_child('inner', inner)
27
+
28
+ @torch.no_grad
29
+ def step(self, vars):
30
+ params = TensorList(vars.params)
31
+ closure = vars.closure
32
+ if closure is None: raise RuntimeError('NewtonSolver requires closure')
33
+
34
+ settings = self.settings[params[0]]
35
+ solver_cls = settings['solver']
36
+ maxiter = settings['maxiter']
37
+ tol = settings['tol']
38
+ reg = settings['reg']
39
+ warm_start = settings['warm_start']
40
+
41
+ # ---------------------- Hessian vector product function --------------------- #
42
+ grad = vars.get_grad(create_graph=True)
43
+
44
+ def H_mm(x):
45
+ with torch.enable_grad():
46
+ Hvp = TensorList(hvp(params, grad, x, create_graph=True))
47
+ if reg != 0: Hvp = Hvp + (x*reg)
48
+ return Hvp
49
+
50
+ # -------------------------------- inner step -------------------------------- #
51
+ b = as_tensorlist(grad)
52
+ if 'inner' in self.children:
53
+ b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
54
+
55
+ # ---------------------------------- run cg ---------------------------------- #
56
+ x0 = None
57
+ if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
58
+ if x0 is None: x = b.zeros_like().requires_grad_(True)
59
+ else: x = x0.clone().requires_grad_(True)
60
+
61
+ solver = solver_cls(x)
62
+ def lstsq_closure(backward=True):
63
+ Hx = H_mm(x)
64
+ loss = (Hx-b).pow(2).global_mean()
65
+ if backward:
66
+ solver.zero_grad()
67
+ loss.backward(inputs=x)
68
+ return loss
69
+
70
+ if maxiter is None: maxiter = b.global_numel()
71
+ loss = None
72
+ initial_loss = lstsq_closure(False)
73
+ if initial_loss > tol:
74
+ for i in range(maxiter):
75
+ loss = solver.step(lstsq_closure)
76
+ assert loss is not None
77
+ if min(loss, loss/initial_loss) < tol: break
78
+
79
+ print(f'{loss = }')
80
+
81
+ if warm_start:
82
+ assert x0 is not None
83
+ x0.copy_(x)
84
+
85
+ vars.update = x.detach()
86
+ return vars
87
+
88
+
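NewtonSolver recovers the Newton direction by minimizing the least-squares residual ||Hx − g||² with a nested optimizer, using Hessian-vector products so the Hessian is never materialized. Below is a standalone plain-torch sketch of that idea; newton_step_by_lstsq is a hypothetical helper and torch.optim.Adam stands in for the inner solver, so this illustrates the technique rather than the torchzero API:

```python
import torch

def newton_step_by_lstsq(loss_fn, w, iters=200, lr=0.5, tol=1e-8):
    loss = loss_fn(w)
    (g,) = torch.autograd.grad(loss, w, create_graph=True)

    x = torch.zeros_like(g).requires_grad_(True)   # candidate Newton step
    solver = torch.optim.Adam([x], lr=lr)

    for _ in range(iters):
        # Hessian-vector product via double backward: Hx = d/dw (g . x)
        (Hx,) = torch.autograd.grad((g * x).sum(), w,
                                    create_graph=True, retain_graph=True)
        residual = (Hx - g).pow(2).mean()
        if residual.item() < tol:
            break
        solver.zero_grad()
        residual.backward(inputs=[x], retain_graph=True)
        solver.step()
    return x.detach()                               # approximately H^{-1} g

# quadratic test: H = diag(1, 10), so the exact Newton step is [3, -2]
w = torch.tensor([3.0, -2.0], requires_grad=True)
print(newton_step_by_lstsq(lambda w: 0.5 * (w[0]**2 + 10 * w[1]**2), w))
```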
torchzero/modules/experimental/{dsoap.py → soapy.py} RENAMED
@@ -3,7 +3,7 @@ from operator import itemgetter
3
3
  import torch
4
4
 
5
5
  from ...core import Chainable, Transform, apply
6
- from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
6
+ from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
7
7
 
8
8
  @torch.no_grad
9
9
  def update_soap_covariances_(
@@ -135,7 +135,7 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
135
135
 
136
136
  return final, exp_avg_sq
137
137
 
138
- class DSOAP(Transform):
138
+ class SOAPY(Transform):
139
139
  """SOAP but uses scaled gradient differences
140
140
 
141
141
  new args
@@ -195,7 +195,7 @@ class DSOAP(Transform):
195
195
  if 'g_prev' not in state:
196
196
  state['p_prev'] = p.clone()
197
197
  state['g_prev'] = t.clone()
198
- updates.append(tensors[i].sign())
198
+ updates.append(tensors[i].clip(-0.1,0.1))
199
199
  continue
200
200
 
201
201
  p_prev = state['p_prev']
@@ -228,7 +228,7 @@ class DSOAP(Transform):
228
228
  state['Q'] = get_orthogonal_matrix(state['GG'])
229
229
 
230
230
  state['step'] = 0
231
- updates.append(tensors[i].sign())
231
+ updates.append(tensors[i].clip(-0.1,0.1))
232
232
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
233
233
  # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
234
234
 
torchzero/modules/experimental/spectral.py CHANGED
@@ -194,8 +194,10 @@ class SpectralPreconditioner(TensorwisePreconditioner):
194
194
  order (int, optional):
195
195
  whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
196
196
  solver (str, optional): what to use for whitening. Defaults to 'svd'.
197
- U_beta (float | None, optional): beta for U (probably a bad idea). Defaults to None.
198
- S_beta (float | None, optional): beta for S (probably a bad idea). Defaults to None.
197
+ A_beta (float | None, optional):
198
+ beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
199
+ B_beta (float | None, optional):
200
+ beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
199
201
  interval (int, optional): How often to update history. Defaults to 1 (every step).
200
202
  concat_params (bool, optional):
201
203
  whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
@@ -275,7 +277,7 @@ class SpectralPreconditioner(TensorwisePreconditioner):
275
277
  A = state.get('A', None)
276
278
  if A is None:
277
279
  # make a conservative step to avoid issues due to different GD scaling
278
- return tensor.div_(max(1, tensor.abs().sum())) # pyright:ignore[reportArgumentType]
280
+ return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
279
281
 
280
282
  B = state['B']
281
283
  update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
torchzero/modules/experimental/subspace_preconditioners.py CHANGED
@@ -38,16 +38,19 @@ def apply_subspace_preconditioner(
38
38
  return basis @ update_projected # d
39
39
 
40
40
  class RandomSubspacePreconditioning(Transform):
41
- """full matrix rmsprop in random subspace"""
42
- def __init__(self, k: int, beta: float | None = 0.99):
43
- defaults = dict(k=k, beta=beta)
41
+ """full matrix rmsprop in random slowly changing subspace"""
42
+ def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
43
+ defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
44
44
  super().__init__(defaults, uses_grad=False)
45
45
 
46
+ if inner is not None: self.set_child('inner', inner)
47
+
46
48
  def transform(self, tensors, params, grads, vars):
47
49
  settings = self.settings[params[0]]
48
50
  g = torch.cat([t.view(-1) for t in tensors])
49
51
  k = settings['k']
50
52
  beta = settings['beta']
53
+ basis_beta = settings['basis_beta']
51
54
 
52
55
  if 'basis' not in self.global_state:
53
56
  self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
@@ -56,13 +59,19 @@ class RandomSubspacePreconditioning(Transform):
56
59
  basis = self.global_state['basis']
57
60
  accumulator = self.global_state['accumulator']
58
61
 
62
+ if basis_beta is not None:
63
+ basis.lerp_(torch.randn_like(basis), 1-basis_beta)
64
+
59
65
  update_subspace_preconditioner_(g, basis, accumulator, beta)
66
+
67
+ if 'inner' in self.children:
68
+ tensors = apply(self.children['inner'], tensors, params, grads, vars)
69
+ g = torch.cat([t.view(-1) for t in tensors])
70
+
60
71
  try:
61
72
  preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
62
73
  except torch.linalg.LinAlgError:
63
- denom = g.abs().sum()
64
- if denom <= 1e-10: denom = torch.ones_like(denom)
65
- preconditioned = g / g.abs().sum()
74
+ preconditioned = g.clip(-0.1, 0.1)
66
75
  vec_to_tensors_(preconditioned, tensors)
67
76
 
68
77
  return tensors
@@ -119,9 +128,7 @@ class HistorySubspacePreconditioning(Transform):
119
128
  try:
120
129
  preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
121
130
  except torch.linalg.LinAlgError:
122
- denom = g.abs().sum()
123
- if denom <= 1e-10: denom = torch.ones_like(denom)
124
- preconditioned = g / g.abs().sum()
131
+ preconditioned = g.clip(-0.1,0.1)
125
132
  vec_to_tensors_(preconditioned, tensors)
126
133
 
127
134
  return tensors
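The new basis_beta option above makes the random projection basis drift instead of staying fixed: every step the basis is interpolated towards a fresh Gaussian matrix. A minimal plain-torch sketch of that update; the variable names are illustrative, not the torchzero internals:

```python
import torch

n, k, basis_beta = 1_000, 8, 0.99
basis = torch.randn(n, k)

for step in range(100):
    # keep 99% of the old basis, mix in 1% fresh Gaussian noise each step
    basis.lerp_(torch.randn_like(basis), 1 - basis_beta)
```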
torchzero/modules/optimizers/soap.py CHANGED
@@ -222,8 +222,7 @@ class SOAP(Transform):
222
222
  state['Q'] = get_orthogonal_matrix(state['GG'])
223
223
 
224
224
  state['step'] = 0
225
- updates.append(tensors[i].sign().div_(10))
226
- # updates.append(tensors[i] / tensors[i].abs().sum())
225
+ updates.append(tensors[i].clip(-0.1, 0.1))
227
226
  continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
228
227
  # I use scaled update instead as to not mess up with next modules.
229
228
 
torchzero/modules/projections/projection.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import math
2
+ from functools import partial
2
3
  from abc import ABC, abstractmethod
3
4
  from collections.abc import Iterable
4
5
  from typing import Any, Literal
@@ -33,6 +34,25 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",
33
34
 
34
35
  return projected_closure
35
36
 
37
+ def _projected_get_grad_override(
38
+ retain_graph: bool | None = None,
39
+ create_graph: bool = False,
40
+ projection: Any = ...,
41
+ unprojected_vars: Any = ...,
42
+ self: Any = ...,
43
+ ):
44
+ assert isinstance(projection, Projection)
45
+ assert isinstance(unprojected_vars, Vars)
46
+ assert isinstance(self, Vars)
47
+
48
+ if self.grad is not None: return self.grad
49
+ grads = unprojected_vars.get_grad(retain_graph, create_graph)
50
+ projected_grads = list(projection.project(grads, self, current='grads'))
51
+ self.grad = projected_grads
52
+ for p, g in zip(self.params, projected_grads):
53
+ p.grad = g
54
+ return self.grad
55
+
36
56
 
37
57
  class Projection(Module, ABC):
38
58
  """
@@ -137,6 +157,12 @@ class Projection(Module, ABC):
137
157
 
138
158
  # step
139
159
  projected_vars.params = self._projected_params
160
+ projected_vars.get_grad = partial(
161
+ _projected_get_grad_override,
162
+ projection=self,
163
+ unprojected_vars=vars,
164
+ self=projected_vars,
165
+ )
140
166
  projected_vars = self.children['modules'].step(projected_vars)
141
167
 
142
168
  # empty fake params storage
@@ -149,7 +175,7 @@ class Projection(Module, ABC):
149
175
  unprojected_vars = projected_vars.clone(clone_update=False)
150
176
  unprojected_vars.closure = vars.closure
151
177
  unprojected_vars.params = vars.params
152
- if unprojected_vars.grad is None: unprojected_vars.grad = vars.grad
178
+ unprojected_vars.grad = vars.grad
153
179
 
154
180
  if self._project_update:
155
181
  assert projected_vars.update is not None
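The functools.partial call above pre-binds the projection and both Vars objects as keyword arguments, so the patched projected_vars.get_grad keeps the original (retain_graph, create_graph) call signature and caches its result. A minimal stand-in sketch of that pattern with fake classes (not the real torchzero types):

```python
from functools import partial

class FakeVars:
    def __init__(self, grad=None):
        self.grad = grad

def projected_get_grad(retain_graph=None, create_graph=False, *,
                       scale=..., source=..., self=...):
    if self.grad is not None:                        # cached on the projected vars
        return self.grad
    self.grad = [g * scale for g in source.grad]     # "project" = scale, for the sketch
    return self.grad

source = FakeVars(grad=[1.0, 2.0])
projected = FakeVars()
projected.get_grad = partial(projected_get_grad,
                             scale=0.5, source=source, self=projected)

print(projected.get_grad())                   # [0.5, 1.0], computed once
print(projected.get_grad(create_graph=True))  # same list, served from the cache
```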
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py CHANGED
@@ -37,10 +37,11 @@ def lbfgs(
37
37
  z_tfm: Any,
38
38
  ):
39
39
  if len(s_history) == 0 or y_k is None or ys_k is None:
40
- # dir = params.grad.sign() # may work fine
41
40
 
42
- # initial step size guess taken from pytorch L-BFGS
43
- return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]
41
+ # initial step size guess modified from pytorch L-BFGS
42
+ scale = 1 / tensors_.abs().global_sum()
43
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
44
+ return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
44
45
 
45
46
  else:
46
47
  # 1st loop
torchzero/modules/quasi_newton/lbfgs.py CHANGED
@@ -36,10 +36,11 @@ def lbfgs(
36
36
  step: int,
37
37
  ):
38
38
  if len(s_history) == 0 or y_k is None or ys_k is None:
39
- # dir = params.grad.sign() # may work fine
40
39
 
41
- # initial step size guess taken from pytorch L-BFGS
42
- return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]
40
+ # initial step size guess modified from pytorch L-BFGS
41
+ scale = 1 / tensors_.abs().global_sum()
42
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
43
+ return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
43
44
 
44
45
  else:
45
46
  # 1st loop
torchzero/modules/quasi_newton/lsr1.py CHANGED
@@ -17,8 +17,9 @@ def lsr1_(
17
17
  ):
18
18
  if step == 0 or not s_history:
19
19
  # initial step size guess from pytorch
20
- tensors_.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
21
- return tensors_
20
+ scale = 1 / tensors_.abs().global_sum()
21
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
22
+ return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
22
23
 
23
24
  m = len(s_history)
24
25
 
@@ -64,7 +65,9 @@ def lsr1_(
64
65
  Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
65
66
 
66
67
  if scale_second and step == 1:
67
- Hx.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
68
+ scale = 1 / tensors_.abs().global_sum()
69
+ if scale < 1e-5: scale = 1 / tensors_.abs().mean()
70
+ Hx.mul_(min(1.0, scale)) # pyright:ignore[reportArgumentType]
68
71
  return Hx
69
72
 
70
73
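The modular_lbfgs, lbfgs and lsr1 hunks above all apply the same tweak to the first-step scale: start from PyTorch L-BFGS's 1/sum(|g|) guess, fall back to 1/mean(|g|) when that guess is so small the step would vanish, and never scale the step up. A standalone sketch of the rule in plain torch; the helper name is hypothetical:

```python
import torch

def first_step_scale(g: torch.Tensor) -> float:
    scale = 1.0 / g.abs().sum().item()
    if scale < 1e-5:                       # sum(|g|) huge, e.g. many parameters
        scale = 1.0 / g.abs().mean().item()
    return min(1.0, scale)

print(first_step_scale(torch.tensor([0.1, -0.2])))   # 1/0.3 ≈ 3.3, clipped to 1.0
print(first_step_scale(torch.randn(1_000_000) * 5))  # falls back to 1/mean(|g|), ≈ 0.25
```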
 
torchzero/modules/quasi_newton/quasi_newton.py CHANGED
@@ -68,6 +68,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
68
68
  M_key = 'H' if inverse else 'B'
69
69
  M = state.get(M_key, None)
70
70
  step = state.get('step', 0)
71
+ state['step'] = step + 1
71
72
  init_scale = settings['init_scale']
72
73
  tol = settings['tol']
73
74
  tol_reset = settings['tol_reset']
@@ -91,13 +92,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
91
92
  state['p_prev'].copy_(p)
92
93
  state['g_prev'].copy_(g)
93
94
 
94
-
95
- if reset_interval is not None and step % reset_interval == 0:
95
+ if reset_interval is not None and step != 0 and step % reset_interval == 0:
96
96
  self._reset_M_(M, s, y, inverse, init_scale)
97
97
  return
98
98
 
99
99
  # tolerance on gradient difference to avoid exploding after converging
100
- if y.abs().max() <= tol:
100
+ elif y.abs().max() <= tol:
101
101
  # reset history
102
102
  if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
103
103
  return
@@ -119,11 +119,10 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
119
119
 
120
120
  @torch.no_grad
121
121
  def apply_tensor(self, tensor, param, grad, state, settings):
122
- step = state['step'] = state.get('step', 0) + 1
122
+ step = state.get('step', 0)
123
123
 
124
124
  if settings['scale_second'] and step == 2:
125
- s = max(1, tensor.abs().sum()) # pyright:ignore[reportArgumentType]
126
- if s < settings['tol']: tensor = tensor/s
125
+ tensor = tensor / tensor.abs().mean().clip(min=1)
127
126
 
128
127
  inverse = settings['inverse']
129
128
  if inverse:
@@ -135,7 +134,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
135
134
  return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable
136
135
 
137
136
  # to avoid typing all arguments for each method
138
- class QuasiNewtonH(HessianUpdateStrategy):
137
+ class HUpdateStrategy(HessianUpdateStrategy):
139
138
  def __init__(
140
139
  self,
141
140
  init_scale: float | Literal["auto"] = "auto",
@@ -174,7 +173,7 @@ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
174
173
  H += term1.sub_(term2)
175
174
  return H
176
175
 
177
- class BFGS(QuasiNewtonH):
176
+ class BFGS(HUpdateStrategy):
178
177
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
179
178
  return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])
180
179
 
@@ -193,7 +192,7 @@ def sr1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
193
192
  H += torch.outer(z, z).div_(denom)
194
193
  return H
195
194
 
196
- class SR1(QuasiNewtonH):
195
+ class SR1(HUpdateStrategy):
197
196
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
198
197
  return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])
199
198
 
@@ -213,7 +212,7 @@ def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
213
212
  H += term1.sub_(term2)
214
213
  return H
215
214
 
216
- class DFP(QuasiNewtonH):
215
+ class DFP(HUpdateStrategy):
217
216
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
218
217
  return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])
219
218
 
@@ -254,19 +253,19 @@ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
254
253
  H -= num/denom
255
254
  return H
256
255
 
257
- class BroydenGood(QuasiNewtonH):
256
+ class BroydenGood(HUpdateStrategy):
258
257
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
259
258
  return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])
260
259
 
261
- class BroydenBad(QuasiNewtonH):
260
+ class BroydenBad(HUpdateStrategy):
262
261
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
263
262
  return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
264
263
 
265
- class Greenstadt1(QuasiNewtonH):
264
+ class Greenstadt1(HUpdateStrategy):
266
265
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
267
266
  return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
268
267
 
269
- class Greenstadt2(QuasiNewtonH):
268
+ class Greenstadt2(HUpdateStrategy):
270
269
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
271
270
  return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
272
271
 
@@ -287,7 +286,7 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
287
286
  H[:, j] += num.squeeze() / denom
288
287
  return H
289
288
 
290
- class ColumnUpdatingMethod(QuasiNewtonH):
289
+ class ColumnUpdatingMethod(HUpdateStrategy):
291
290
  """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
292
291
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
293
292
  return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
@@ -307,7 +306,7 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor,
307
306
  H -= num/denom
308
307
  return H, R
309
308
 
310
- class ThomasOptimalMethod(QuasiNewtonH):
309
+ class ThomasOptimalMethod(HUpdateStrategy):
311
310
  """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
312
311
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
313
312
  if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
@@ -364,7 +363,7 @@ def pearson2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
364
363
  H += num.div_(sy)
365
364
  return H
366
365
 
367
- class Pearson2(QuasiNewtonH):
366
+ class Pearson2(HUpdateStrategy):
368
367
  """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
369
368
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
370
369
  return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
@@ -1,3 +1,3 @@
1
1
  from .newton import Newton
2
2
  from .newton_cg import NewtonCG
3
- from .nystrom import NystromSketchAndSolve, NystromPCG
3
+ from .nystrom import NystromSketchAndSolve, NystromPCG
torchzero/modules/second_order/newton_cg.py CHANGED
@@ -76,7 +76,7 @@ class NewtonCG(Module):
76
76
  x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
77
77
  if warm_start:
78
78
  assert x0 is not None
79
- x0.set_(x)
79
+ x0.copy_(x)
80
80
 
81
81
  vars.update = x
82
82
  return vars
torchzero/utils/linalg/benchmark.py ADDED
@@ -0,0 +1,20 @@
1
+ from collections.abc import Callable
2
+
3
+ import torch
4
+
5
+
6
+ def benchmark_solver(
7
+ A: torch.Tensor | Callable[[torch.Tensor], torch.Tensor],
8
+ b: torch.Tensor,
9
+ solver: Callable[[Callable[[torch.Tensor], torch.Tensor], torch.Tensor], torch.Tensor]
10
+ ):
11
+ residuals = []
12
+ def A_mm(x):
13
+ if callable(A): Ax = A(x)
14
+ else: Ax = A@x
15
+ residuals.append(torch.linalg.vector_norm(Ax-b)) # pylint:disable=not-callable
16
+ return Ax
17
+
18
+ solver(A_mm, b)
19
+ return residuals
20
+
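A possible way to exercise the new benchmark_solver helper: it records the residual norm at every matrix-vector product the wrapped solver performs. The import path follows the RECORD listing below and is an assumption; the simple Richardson iteration is mine, used only so each A_mm call is visible:

```python
import torch
from torchzero.utils.linalg.benchmark import benchmark_solver

A = torch.tensor([[4.0, 1.0],
                  [1.0, 3.0]])
b = torch.tensor([1.0, 2.0])

def richardson(A_mm, rhs, iters=50, step=0.2):
    # deliberately simple fixed-point solver: x <- x + step * (b - A x)
    x = torch.zeros_like(rhs)
    for _ in range(iters):
        x = x + step * (rhs - A_mm(x))
    return x

residuals = benchmark_solver(A, b, solver=richardson)
print(float(residuals[0]), float(residuals[-1]))   # first vs last residual norm
```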
torchzero/utils/linalg/solve.py CHANGED
@@ -8,27 +8,27 @@ from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_nume
8
8
  def cg(
9
9
  A_mm: Callable[[torch.Tensor], torch.Tensor],
10
10
  b: torch.Tensor,
11
- x0_: torch.Tensor | None,
12
- tol: float | None,
13
- maxiter: int | None,
11
+ x0_: torch.Tensor | None = None,
12
+ tol: float | None = 1e-4,
13
+ maxiter: int | None = None,
14
14
  reg: float = 0,
15
15
  ) -> torch.Tensor: ...
16
16
  @overload
17
17
  def cg(
18
18
  A_mm: Callable[[TensorList], TensorList],
19
19
  b: TensorList,
20
- x0_: TensorList | None,
21
- tol: float | None,
22
- maxiter: int | None,
20
+ x0_: TensorList | None = None,
21
+ tol: float | None = 1e-4,
22
+ maxiter: int | None = None,
23
23
  reg: float | list[float] | tuple[float] = 0,
24
24
  ) -> TensorList: ...
25
25
 
26
26
  def cg(
27
27
  A_mm: Callable,
28
28
  b: torch.Tensor | TensorList,
29
- x0_: torch.Tensor | TensorList | None,
30
- tol: float | None,
31
- maxiter: int | None,
29
+ x0_: torch.Tensor | TensorList | None = None,
30
+ tol: float | None = 1e-4,
31
+ maxiter: int | None = None,
32
32
  reg: float | list[float] | tuple[float] = 0,
33
33
  ):
34
34
  def A_mm_reg(x): # A_mm with regularization
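With the defaults added above, cg can now be called with just the operator and the right-hand side. A small usage sketch that assumes only the signature shown in this hunk; the import path is taken from the RECORD listing:

```python
import torch
from torchzero.utils.linalg.solve import cg

A = torch.tensor([[4.0, 1.0],
                  [1.0, 3.0]])
b = torch.tensor([1.0, 2.0])

x = cg(lambda v: A @ v, b)    # x0_=None, tol=1e-4, maxiter=None by default
print(A @ x - b)              # residual close to zero
```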
@@ -90,7 +90,7 @@ def nystrom_sketch_and_solve(
90
90
  A_mm: Callable[[torch.Tensor], torch.Tensor],
91
91
  b: torch.Tensor,
92
92
  rank: int,
93
- reg: float,
93
+ reg: float = 1e-3,
94
94
  generator=None,
95
95
  ) -> torch.Tensor:
96
96
  U, lambd = nystrom_approximation(
@@ -116,10 +116,10 @@ def nystrom_pcg(
116
116
  A_mm: Callable[[torch.Tensor], torch.Tensor],
117
117
  b: torch.Tensor,
118
118
  sketch_size: int,
119
- reg: float,
120
- x0_: torch.Tensor | None,
121
- tol: float | None,
122
- maxiter: int | None,
119
+ reg: float = 1e-6,
120
+ x0_: torch.Tensor | None = None,
121
+ tol: float | None = 1e-4,
122
+ maxiter: int | None = None,
123
123
  generator=None,
124
124
  ) -> torch.Tensor:
125
125
  U, lambd = nystrom_approximation(
@@ -166,3 +166,4 @@ def nystrom_pcg(
166
166
  z = P_inv @ residual
167
167
  beta = residual.dot(z) / rz
168
168
  p = z + p*beta
169
+
{torchzero-0.3.5.dist-info → torchzero-0.3.8.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchzero
3
- Version: 0.3.5
3
+ Version: 0.3.8
4
4
  Summary: Modular optimization library for PyTorch.
5
5
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
6
  License: MIT License
@@ -156,7 +156,7 @@ for epoch in range(100):
156
156
  * `Newton`: Classic Newton's method.
157
157
  * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
158
158
  * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
159
- * `NystromPCG`: NewtonCG with Nyström preconditioning (my current recommendation).
159
+ * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
160
160
 
161
161
  * **Quasi-Newton**: Approximate second-order optimization methods.
162
162
  * `LBFGS`: Limited-memory BFGS.
@@ -1,14 +1,14 @@
1
1
  docs/source/conf.py,sha256=jd80ZT2IdCx7nlQrpOTJL8UhGBNm6KYyXlpp0jmRiAw,1849
2
2
  tests/test_identical.py,sha256=NZ7A8Rm1U9Q16d-cG2G_wccpPtNALyoKYJt9qMownMc,11568
3
3
  tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
4
- tests/test_opts.py,sha256=oDZVFr9AE9ZhyR-sImSgNzQsbPsUtJLzuLd1Nxgkp1w,40850
5
- tests/test_tensorlist.py,sha256=6JTbhvABzXLpbYD-1m3YyPk_KHREMEOTSg4gGpJLuNc,72427
4
+ tests/test_opts.py,sha256=XfpDaVwOC2VuG700BXWAFWiemeVW0ucLG74yfns9mB8,40849
5
+ tests/test_tensorlist.py,sha256=VWX9wYdfkG-0Y8I0wWPp56ZJM0mBNPvS_SC3irmcYcs,72427
6
6
  tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
7
7
  tests/test_vars.py,sha256=3p9dsHk7SJpMd-WRD0ziBNq5FEHRBJGSxbMLD8ES4J0,6815
8
8
  torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
9
9
  torchzero/core/__init__.py,sha256=2JRyeGZprTexAeEPQOIl9fLFGBwzvya-AwKyt7XAmGQ,210
10
10
  torchzero/core/module.py,sha256=Razw3c71Kfegznm0vQxsii1KuTUCPBC9UGyq2v-KX4M,27568
11
- torchzero/core/preconditioner.py,sha256=rMYusKbaypm5K0Ii9VdjKhxi2YWNQbBk9f6AV_MJulY,6191
11
+ torchzero/core/preconditioner.py,sha256=R1IGk7Tbea5wSkazpnXwusjvBxzJHzEWgCtR_nEz2w4,6258
12
12
  torchzero/core/transform.py,sha256=ajNJcX45ds-_lc5CqxgLfEFGil6_BYLerB0WvoTi8rM,10303
13
13
  torchzero/modules/__init__.py,sha256=BDeyuSd2s1WFUUXIo3tGTNp4aYp4A2B94cydpPW24nY,332
14
14
  torchzero/modules/functional.py,sha256=HXNzmPe7LsPadryEm7zrcEKqGej16QDwSgBkbEvggFM,6492
@@ -16,18 +16,19 @@ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLB
16
16
  torchzero/modules/clipping/clipping.py,sha256=I-5utyrqdKtF5yaH-9m2F3UqdfpPmA2bSSFUAZ_d60Q,12544
17
17
  torchzero/modules/clipping/ema_clipping.py,sha256=pLeNuEBLpJ74io2sHn_ZVYaQ6ydEfhpVfVEX2bFttd0,5947
18
18
  torchzero/modules/clipping/growth_clipping.py,sha256=OD-kdia2Rn-DvYlYV6EZlGPDVTh9tj-W9mpiZPc3hOQ,6772
19
- torchzero/modules/experimental/__init__.py,sha256=sJ6URgX35P3zJ2ugBKgAcwBWmdBmAPDW3vXHQ0sK-ro,443
20
- torchzero/modules/experimental/absoap.py,sha256=XUHr5SeLdhLW2kMvWea5xAqZeuJBDQoO4zprDxs4bgU,13317
21
- torchzero/modules/experimental/adadam.py,sha256=W7rRXYJ9tGrzqD_FdFX00HBLuWOEr2tHtfshf6lDFYE,4049
22
- torchzero/modules/experimental/adamY.py,sha256=FoSn-qMI5_BdqZH10WGKkl-zYTPESBdGZ9lfhyqnbB0,4591
23
- torchzero/modules/experimental/adasoap.py,sha256=07gPdEdBIKtmdmSzTGtTO0c2ZkS_otVLufQ76okBjHY,11239
24
- torchzero/modules/experimental/algebraic_newton.py,sha256=_XFYR6bdHWgA5bozxc9AJYteBIAnHrSLgo_bSaZ13eg,5193
25
- torchzero/modules/experimental/curveball.py,sha256=Nw9jtSp5QNj7-FN3qshjYEDHc68LwRLha-Co78mfR5w,3242
26
- torchzero/modules/experimental/dsoap.py,sha256=BEZDw3_n5VDhu7VLgkoSN4rI9JeBdGoO9gFZfqsh74M,10983
27
- torchzero/modules/experimental/gradmin.py,sha256=55dpBDNyrYJusluFhw-v1BXuj1UxER7pNEPTtwYKD4E,3648
19
+ torchzero/modules/experimental/__init__.py,sha256=fEPDYDl7qhaFoferDRmG3ehwuqSvx4Vt2uOz0Y7h4to,483
20
+ torchzero/modules/experimental/absoap.py,sha256=Z4MS4pDPSQ9IaTk8g57OfrsWcYVOT72x533KKtn2Zxk,13512
21
+ torchzero/modules/experimental/adadam.py,sha256=OAPF1-NUbg79V3QOTYzsQlRC97C7XHj5boOLDqLz3PE,4029
22
+ torchzero/modules/experimental/adamY.py,sha256=g1pAHwgdyDdKvObZ67lCSc36L99tl5jlQgOr4lMJCDo,4595
23
+ torchzero/modules/experimental/adasoap.py,sha256=PduidICEGYICIvlYysYCFZF7-QhNX0YlhHfPhLONnUs,11247
24
+ torchzero/modules/experimental/algebraic_newton.py,sha256=sq5ZD_j_EtlxIjNnS0rKKwTSG_JuwsZOg9ZMMQTuQm0,5154
25
+ torchzero/modules/experimental/curveball.py,sha256=Uk30uLEztTHD5IUJLJm9Nn3x31DF9kQHmeLFhc065us,3262
26
+ torchzero/modules/experimental/gradmin.py,sha256=iJmEvDEdVdck0C-94pY3iGxnIoNv6Fu6vj3f7lS6aQM,3686
27
+ torchzero/modules/experimental/newton_solver.py,sha256=iGI2LHLaZd2ovpbq1Vogs76os0zWG7VwM7nUz8RzxVg,3071
28
28
  torchzero/modules/experimental/reduce_outward_lr.py,sha256=kjtRwepBGBca77ToM-lw3b8ywptMtmSdC_jQfjJAwlY,1184
29
- torchzero/modules/experimental/spectral.py,sha256=D3_nCI8teFirCdnnLprNnZ3G1gsOB6RUBWCeDbwi7P0,12043
30
- torchzero/modules/experimental/subspace_preconditioners.py,sha256=4SRJOyTG-fJCGunHR62aRrzw3qFmeI6fRQAYHIadhWw,4682
29
+ torchzero/modules/experimental/soapy.py,sha256=Ishd2Jj6BbhjrLyC48zf-cjMmA1kJb_uKXESQBIML_s,10990
30
+ torchzero/modules/experimental/spectral.py,sha256=8_n208V2yPY3z5pCym-FvwO7DGFhozNgWlpIBtQSdrI,12139
31
+ torchzero/modules/experimental/subspace_preconditioners.py,sha256=WnHpga7Kx4-N2xU5vP3uUHRER70ymyNJCWbSx2zXWOk,4976
31
32
  torchzero/modules/experimental/tropical_newton.py,sha256=uq66ouhgrgc8iYGozDQ3_rtbubj8rKRwb1jfcdnlpHg,4903
32
33
  torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
33
34
  torchzero/modules/grad_approximation/fdm.py,sha256=2PNNBIMup1xlOwLFAwAS3xAVd-7GGVyerMeKH1ug9LQ,3591
@@ -70,25 +71,25 @@ torchzero/modules/optimizers/orthograd.py,sha256=5BLnNJTYuGUClHmlxaXZ1jNvBR4zSFD
70
71
  torchzero/modules/optimizers/rmsprop.py,sha256=d10Y9Ck-391tVysO3xMHg3g2Pe0UEZplgebEyDYi3Z4,4333
71
72
  torchzero/modules/optimizers/rprop.py,sha256=n4k5-9F3ppH0Xl-4l4vNXfqVf2r67vMPCkstUaQKPLw,10974
72
73
  torchzero/modules/optimizers/shampoo.py,sha256=AHHV6d71DqKDPCg52ShWIPIRSGtWkMc1v1XwXgDG3qY,8606
73
- torchzero/modules/optimizers/soap.py,sha256=HL1YrfiEiRMh6aW9D5UEZXBjo3yMTqnpKPHXVD8fOa8,11590
74
+ torchzero/modules/optimizers/soap.py,sha256=Kf2BAtIf2QY1V2ZJcUjRLcp2WfIVLd3mNclnaT3Nmds,11520
74
75
  torchzero/modules/optimizers/sophia_h.py,sha256=8pSlYVm66xWplzdP8MX3MCTzzIYHsxGzDEXJKA03Zgg,4279
75
76
  torchzero/modules/projections/__init__.py,sha256=OCxlh_-Tx-xpl31X03CeFJH9XveH563oEsWc8rUvX0A,196
76
77
  torchzero/modules/projections/dct.py,sha256=wxaEV6dTNiOqW_n2UHX0De6mMXTKDXK6UNcMNI4Rogk,2373
77
78
  torchzero/modules/projections/fft.py,sha256=OpCcEM1-A2dgk1umwRsBsvK7ObiHtsBKlkkcw0IX83Q,2961
78
79
  torchzero/modules/projections/galore.py,sha256=c9CZ0kHxpKEoyfc_lnmeHOkNp55jCppb7onN5YmWnN8,242
79
- torchzero/modules/projections/projection.py,sha256=tvUBZ4XGY1GkOg6jrKS7FvpIpjUc2FJL_SMRpoROT1E,9330
80
+ torchzero/modules/projections/projection.py,sha256=aYufSD3ftRUqVScPmqxwEFgP1P8ioxM8z9eyzaL7d10,10147
80
81
  torchzero/modules/projections/structural.py,sha256=QaCGHmzHCXj46sM-XZ5XlYU9BnuRKI2ReR3LE8y2R4g,5740
81
82
  torchzero/modules/quasi_newton/__init__.py,sha256=0iOlX73PHj9lQS3_2cJ5lyCdas904MnFfIvR8Popvzw,402
82
83
  torchzero/modules/quasi_newton/cg.py,sha256=h-di1oKKP1tDoh-LogBRIRCp2UF9GA6XjEJPlX6xXf4,9322
83
- torchzero/modules/quasi_newton/lbfgs.py,sha256=jtO5ldbx66yUWv-20c-4mvq6HhCMuomCwJK8A8bjcYA,9168
84
- torchzero/modules/quasi_newton/lsr1.py,sha256=F_DtMQZfQSjmSLjnx4nw16AV7qCdNxT9ITQbfNFrPdM,5879
84
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=U7FKNqFTRdabf1_UUCCEO3JoDlOnWzGBhYNvVg138gg,9199
85
+ torchzero/modules/quasi_newton/lsr1.py,sha256=BuoztcRo0lm7WW3vKtDQcfKizF-9WPItOY_X9Ng1ZS8,6033
85
86
  torchzero/modules/quasi_newton/olbfgs.py,sha256=2YAOXlMnPGw22sNcIMH1hmggzAXQRbN59RSPUZNKUZY,8352
86
- torchzero/modules/quasi_newton/quasi_newton.py,sha256=jwQkzlnozIaxHW9kuDAAlME0YuQdrdZX9OZZoTmej4Q,17384
87
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=5FutzdBNpx6P8Qun9LjXd-rsy2nY2bkpQ0z0cLPnbJo,17373
87
88
  torchzero/modules/quasi_newton/experimental/__init__.py,sha256=3qpZGgdsx6wpoafWaNWx-eamRl1FuxVCWQZq8Y7Cl98,39
88
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=PlyuIH2pFazIR89OGTrZESt752GkbArh_Zb8mtVCOi0,10731
89
- torchzero/modules/second_order/__init__.py,sha256=5lRmwIU53eRc1owpOZ5FMDc7u1Z48I3PDc0NyCBaJNM,113
89
+ torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=ec6JKYX89xA_UlY9VrMB3hBjDyNKwkalS_4JQGA1qOY,10762
90
+ torchzero/modules/second_order/__init__.py,sha256=jolCGaIVkID9hpxgx0Tc22wgjVlwuWekWjKTMe5jKXw,114
90
91
  torchzero/modules/second_order/newton.py,sha256=XNhscAuWwxOUwps3sUrxc2ExgkNFbilnAdszrCvQxFg,5845
91
- torchzero/modules/second_order/newton_cg.py,sha256=lUVn4-ZoW3qAxqEy8i7yz_aN7sZDoQChd-A_Ubrz-Ag,2871
92
+ torchzero/modules/second_order/newton_cg.py,sha256=stVySgo7tmvntd-tuAzThzpWmZzfTnmn8ISQa5Oi4yw,2872
92
93
  torchzero/modules/second_order/nystrom.py,sha256=ZyCWrde-_-Ednj46jafuvBOzG3nC-3cPYGr-HytZbsE,6073
93
94
  torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
94
95
  torchzero/modules/smoothing/gaussian.py,sha256=YlT_G4MqAVkiWG56RHAwgt5SSPISpvQZQbSLh8mhF3I,6153
@@ -116,13 +117,14 @@ torchzero/utils/python_tools.py,sha256=RFBqNj8w52dpJ983pUPPDbg2x1MX_-SsBnBMffWGG
116
117
  torchzero/utils/tensorlist.py,sha256=qSbiliVo1euFAksdHHHRbPUdYYxfkw1dvhpXj71wGy0,53162
117
118
  torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
118
119
  torchzero/utils/linalg/__init__.py,sha256=Dzbho3_z7JDdKzYD-QdLArg0ZEoC2BVGdlE3JoAnXHQ,272
120
+ torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
119
121
  torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
120
122
  torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
121
123
  torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
122
- torchzero/utils/linalg/solve.py,sha256=hN450ONzAirYOvWF2g0E0Wy2n1bCw4X-KXWi6p4jvDM,5136
124
+ torchzero/utils/linalg/solve.py,sha256=P0PMi0zro3G3Rd0X-JeoLk7tqYDB0js0aB4bpQ0OABU,5235
123
125
  torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
124
- torchzero-0.3.5.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
125
- torchzero-0.3.5.dist-info/METADATA,sha256=ZgqGz-rYGTWgbhM0K6CwIocmsmuKtzjdc-Y6nGibDZA,13944
126
- torchzero-0.3.5.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
127
- torchzero-0.3.5.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
128
- torchzero-0.3.5.dist-info/RECORD,,
126
+ torchzero-0.3.8.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
127
+ torchzero-0.3.8.dist-info/METADATA,sha256=vj5aue0pVG8xNStpOEvPfln422K5fpV3BKF-H7ZlhRQ,13941
128
+ torchzero-0.3.8.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
129
+ torchzero-0.3.8.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
130
+ torchzero-0.3.8.dist-info/RECORD,,