torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/__init__.py CHANGED
@@ -1,4 +1,4 @@
  from . import core, optim, utils
- from .core import Modular
+ from .core import Optimizer
  from .utils.compile import enable_compilation
  from . import modules as m
torchzero/core/__init__.py CHANGED
@@ -3,6 +3,6 @@ from .module import Chainable, Module
  from .objective import DerivativesMethod, HessianMethod, HVPMethod, Objective

  # order is important to avoid circular imports
- from .modular import Modular
+ from .modular import Optimizer
  from .functional import apply, step, step_tensors, update
  from .chain import Chain, maybe_chain
torchzero/core/functional.py CHANGED
@@ -96,7 +96,7 @@ def step_tensors(
  objective.updates = list(tensors)

  # step with modules
- # this won't update parameters in-place because objective.Modular is None
+ # this won't update parameters in-place because objective.Optimizer is None
  objective = _chain_step(objective, modules)

  # return updates
torchzero/core/modular.py CHANGED
@@ -15,7 +15,7 @@ from .objective import Objective
  class _EvalCounterClosure:
  """keeps track of how many times closure has been evaluated, and sets closure return"""
  __slots__ = ("modular", "closure")
- def __init__(self, modular: "Modular", closure):
+ def __init__(self, modular: "Optimizer", closure):
  self.modular = modular
  self.closure = closure

@@ -46,9 +46,9 @@ def flatten_modules(*modules: Chainable) -> list[Module]:
  return flat


- # have to inherit from Modular to support lr schedulers
+ # have to inherit from Optimizer to support lr schedulers
  # although Accelerate doesn't work due to converting param_groups to a dict
- class Modular(torch.optim.Optimizer):
+ class Optimizer(torch.optim.Optimizer):
  """Chains multiple modules into an optimizer.

  Args:
@@ -62,7 +62,7 @@ class Modular(torch.optim.Optimizer):
  param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]

  def __init__(self, params: Params | torch.nn.Module, *modules: Module):
- if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
+ if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Optimizer`")
  self.model: torch.nn.Module | None = None
  """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
  if isinstance(params, torch.nn.Module):
@@ -229,5 +229,5 @@ class Modular(torch.optim.Optimizer):
  return self._closure_return

  def __repr__(self):
- return f'Modular({", ".join(str(m) for m in self.modules)})'
+ return f'Optimizer({", ".join(str(m) for m in self.modules)})'

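The user-facing change in this file is the rename of the chaining class from Modular to Optimizer; the constructor still accepts either an iterable of parameters or an nn.Module followed by one or more modules, and the class still subclasses torch.optim.Optimizer, so lr-scheduler integrations should keep working. A minimal sketch of the new entry point (the choice of Adagrad, its default arguments, and the closure-free step() are assumptions for illustration, not shown in this diff):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    # 0.4.0: opt = tz.Modular(model, tz.m.Adagrad())
    # 0.4.1: the chaining class is now exported as Optimizer
    opt = tz.Optimizer(model, tz.m.Adagrad())

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()  # assumes this module chain does not require a closure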
torchzero/core/module.py CHANGED
@@ -35,7 +35,7 @@ class Module(ABC):

  # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
  # 0 - this module specific per-parameter setting overrides set via `set_param_groups` - highest priority
- # 1 - global per-parameter setting overrides in param_groups passed to Modular - medium priority
+ # 1 - global per-parameter setting overrides in param_groups passed to Optimizer - medium priority
  # 2 - `defaults` - lowest priority
  self.settings: defaultdict[torch.Tensor, ChainMap[str, Any]] = defaultdict(lambda: ChainMap({}, {}, self.defaults))
  """per-parameter settings."""
@@ -273,7 +273,7 @@ class Module(ABC):
  return state_dict

  def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
- """loads state_dict, ``id_to_tensor`` is passed by ``Modular``"""
+ """loads state_dict, ``id_to_tensor`` is passed by ``Optimizer``"""
  # load state
  state = state_dict['state']
  self.state.clear()
torchzero/core/objective.py CHANGED
@@ -20,7 +20,7 @@ from ..utils.derivatives import (
  from ..utils.thoad_tools import thoad_derivatives, thoad_single_tensor, lazy_thoad

  if TYPE_CHECKING:
- from .modular import Modular
+ from .modular import Optimizer
  from .module import Module

  def _closure_backward(closure, params, backward, retain_graph, create_graph):
@@ -135,13 +135,13 @@ class Objective:
  model (torch.nn.Module | None, optional):
  ``torch.nn.Module`` object, needed for a few modules that require access to the model. Defaults to None.
  current_step (int, optional):
- number of times ``Modular.step()`` has been called, starting at 0. Defaults to 0.
+ number of times ``Optimizer.step()`` has been called, starting at 0. Defaults to 0.
  parent (Objective | None, optional):
  parent ``Objective`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
  Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
  e.g. when projecting. Defaults to None.
- modular (Modular | None, optional):
- Top-level ``Modular`` optimizer. Defaults to None.
+ modular (Optimizer | None, optional):
+ Top-level ``Optimizer`` optimizer. Defaults to None.
  storage (dict | None, optional):
  additional kwargs passed to ``step`` to control some module-specific behavior. Defaults to None.

@@ -154,7 +154,7 @@ class Objective:
  model: torch.nn.Module | None = None,
  current_step: int = 0,
  parent: "Objective | None" = None,
- modular: "Modular | None" = None,
+ modular: "Optimizer | None" = None,
  storage: dict | None = None,
  ):
  self.params: list[torch.Tensor] = list(params)
@@ -175,8 +175,8 @@
  Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
  e.g. when projecting."""

- self.modular: "Modular | None" = modular
- """Top-level ``Modular`` optimizer, ``None`` if it wasn't specified."""
+ self.modular: "Optimizer | None" = modular
+ """Top-level ``Optimizer`` optimizer, ``None`` if it wasn't specified."""
  self.updates: list[torch.Tensor] | None = None
  """
@@ -222,7 +222,7 @@ class Objective:
  # """Storage for any other data, such as hessian estimates, etc."""

  self.attrs: dict = {}
- """attributes, ``Modular.attrs`` is updated with this after each step.
+ """attributes, ``Optimizer.attrs`` is updated with this after each step.
  This attribute should always be modified in-place"""

  if storage is None: storage = {}
@@ -231,7 +231,7 @@ class Objective:
  This attribute should always be modified in-place"""

  self.should_terminate: bool | None = None
- """termination criteria, ``Modular.should_terminate`` is set to this after each step if not ``None``"""
+ """termination criteria, ``Optimizer.should_terminate`` is set to this after each step if not ``None``"""

  self.temp: Any = cast(Any, None)
  """temporary storage, ``Module.update`` can set this and ``Module.apply`` access via ``objective.poptemp()``.
@@ -756,7 +756,7 @@ class Objective:
  if g_list is not None and self.grads is None:
  self.grads = list(g_list)

- return f, g_list, H
+ return f, g_list, H.detach()

  @torch.no_grad
  def derivatives(self, order: int, at_x0: bool, method:DerivativesMethod="batched_autograd"):
torchzero/core/transform.py CHANGED
@@ -233,7 +233,7 @@ class TensorTransform(Transform):
  if self._uses_grad: grads = objective.get_grads()
  else: grads = None # better explicitly set to None rather than objective.grads because it shouldn't be used

- if self._uses_loss: loss = objective.get_loss(backward=False)
+ if self._uses_loss: loss = objective.get_loss(backward=True)
  else: loss = None

  return grads, loss
torchzero/linalg/__init__.py CHANGED
@@ -3,8 +3,9 @@ from . import linear_operator
  from .matrix_power import (
  matrix_power_eigh,
  matrix_power_svd,
+ MatrixPowerMethod,
  )
- from .orthogonalize import zeropower_via_eigh, zeropower_via_newtonschulz5, zeropower_via_svd, orthogonalize
+ from .orthogonalize import zeropower_via_eigh, zeropower_via_newtonschulz5, zeropower_via_svd, orthogonalize,OrthogonalizeMethod
  from .qr import qr_householder
  from .solve import cg, nystrom_sketch_and_solve, nystrom_pcg
- from .eigh import nystrom_approximation
+ from .eigh import nystrom_approximation, regularize_eigh
torchzero/linalg/eigh.py CHANGED
@@ -1,7 +1,11 @@
  from collections.abc import Callable
+
  import torch
- from .linalg_utils import mm

+ from . import torch_linalg
+ from .linalg_utils import mm
+ from .orthogonalize import OrthogonalizeMethod, orthogonalize
+ from .svd import tall_reduced_svd_via_eigh


  # https://arxiv.org/pdf/2110.02820
@@ -11,6 +15,8 @@ def nystrom_approximation(
  ndim: int,
  rank: int,
  device,
+ orthogonalize_method: OrthogonalizeMethod = 'qr',
+ eigv_tol: float = 0,
  dtype = torch.float32,
  generator = None,
  ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -20,7 +26,7 @@
  A is ``(m,m)``, then Q is ``(m, rank)``; L is a ``(rank, )`` vector - diagonal of ``(rank, rank)``"""
  # basis
  O = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
- O, _ = torch.linalg.qr(O) # Thin QR decomposition # pylint:disable=not-callable
+ O = orthogonalize(O, method=orthogonalize_method) # Thin QR decomposition # pylint:disable=not-callable

  # Y = AΩ
  AO = mm(A_mv=A_mv, A_mm=A_mm, X=O)
@@ -29,6 +35,219 @@
  Yv = AO + v*O # Shift for stability
  C = torch.linalg.cholesky_ex(O.mT @ Yv)[0] # pylint:disable=not-callable
  B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
- Q, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
- L = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
+
+ # Q, S, _ = torch_linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
+ # B is (ndim, rank) so we can use eigendecomp of (rank, rank)
+ Q, S = tall_reduced_svd_via_eigh(B, tol=eigv_tol, retry_float64=True)
+
+ L = S.pow(2) - v
+ return L, Q
+
+
+ def regularize_eigh(
+ L: torch.Tensor,
+ Q: torch.Tensor,
+ truncate: int | None = None,
+ tol: float | None = None,
+ damping: float = 0,
+ rdamping: float = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
+ """Applies regularization to eigendecomposition. Returns ``(L, Q)``.
+
+ Args:
+ L (torch.Tensor): eigenvalues, shape ``(rank,)``.
+ Q (torch.Tensor): eigenvectors, shape ``(n, rank)``.
+ truncate (int | None, optional):
+ keeps top ``truncate`` eigenvalues. Defaults to None.
+ tol (float | None, optional):
+ all eigenvalues smaller than largest eigenvalue times ``tol`` are removed. Defaults to None.
+ damping (float | None, optional): scalar added to eigenvalues. Defaults to 0.
+ rdamping (float | None, optional): scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.
+ """
+ # remove non-finite eigenvalues
+ finite = L.isfinite()
+ if finite.any():
+ L = L[finite]
+ Q = Q[:, finite]
+ else:
+ return None, None
+
+ # largest finite!!! eigval
+ L_max = L[-1] # L is sorted in ascending order
+
+ # remove small eigenvalues relative to largest
+ if tol is not None:
+ indices = L > tol * L_max
+ L = L[indices]
+ Q = Q[:, indices]
+
+ # truncate to rank (L is ordered in ascending order)
+ if truncate is not None:
+ L = L[-truncate:]
+ Q = Q[:, -truncate:]
+
+ # damping
+ d = damping + rdamping * L_max
+ if d != 0:
+ L += d
+
  return L, Q
+
+ def eigh_plus_uuT(
+ L: torch.Tensor,
+ Q: torch.Tensor,
+ u: torch.Tensor,
+ alpha: float = 1,
+ tol: float | None = None,
+ retry_float64: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ compute eigendecomposition of Q L Q^T + alpha * (u u^T) where Q is ``(m, rank)`` and L is ``(rank, )`` and u is ``(m, )``
+ """
+ if tol is None: tol = torch.finfo(Q.dtype).eps
+ z = Q.T @ u # (rank,)
+
+ # component of u orthogonal to the column space of Q
+ res = u - Q @ z # (m,)
+ beta = torch.linalg.vector_norm(res) # pylint:disable=not-callable
+
+ if beta < tol:
+ # u is already in the column space of Q
+ B = L.diag_embed().add_(z.outer(z), alpha=alpha) # (rank, rank)
+ L_prime, S = torch_linalg.eigh(B, retry_float64=retry_float64)
+ Q_prime = Q @ S
+ return L_prime, Q_prime
+
+ # normalize the orthogonal component to get a new orthonormal vector
+ v = res / beta # (m, )
+
+ # project and compute new eigendecomposition
+ D_diag = torch.cat([L, torch.tensor([0.0], device=Q.device, dtype=Q.dtype)])
+ w = torch.cat([z, beta.unsqueeze(0)]) # Shape: (rank+1,)
+ B = D_diag.diag_embed().add_(w.outer(w), alpha=alpha)
+
+ L_prime, S = torch_linalg.eigh(B, retry_float64=retry_float64)
+
+ # unproject and sort
+ basis = torch.cat([Q, v.unsqueeze(-1)], dim=1) # (m, rank+1)
+ Q_prime = basis @ S # (m, rank+1)
+
+ idx = torch.argsort(L_prime)
+ L_prime = L_prime[idx]
+ Q_prime = Q_prime[:, idx]
+
+ return L_prime, Q_prime
+
+ def eigh_plus_UUT(
+ L: torch.Tensor,
+ Q: torch.Tensor,
+ U: torch.Tensor,
+ alpha: float = 1,
+ tol = None,
+ retry_float64: bool = False,
+ ):
+ """
+ compute eigendecomposition of Q L Q^T + alpha * (U U^T), where Q is ``(m, rank)`` and L is ``(rank, )``,
+ U is ``(m, k)`` where k is rank of correction
+ """
+ if U.size(1) == 1:
+ return eigh_plus_uuT(L, Q, U[:,0], alpha=alpha, tol=tol, retry_float64=retry_float64)
+
+ if tol is None: tol = torch.finfo(Q.dtype).eps
+ m, r = Q.shape
+
+ Z = Q.T @ U # (r, k)
+ U_res = U - Q @ Z # (m, k)
+
+ # find cols of U not in col space of Q
+ res_norms = torch.linalg.vector_norm(U_res, dim=0) # pylint:disable=not-callable
+ new_indices = torch.where(res_norms > tol)[0]
+ k_prime = len(new_indices)
+
+ if k_prime == 0:
+ # all cols are in Q
+ B = Q
+ C = Z # (r x k)
+ r_new = r
+ else:
+ # orthonormalize directions that aren't in Q
+ U_new = U_res[:, new_indices]
+ Q_u, _ = torch_linalg.qr(U_new, mode='reduced', retry_float64=retry_float64)
+ B = torch.hstack([Q, Q_u])
+ C = torch.vstack([Z, Q_u.T @ U])
+ r_new = r + k_prime
+
+
+ # project and compute new eigendecomposition
+ A_proj = torch.zeros((r_new, r_new), device=Q.device, dtype=Q.dtype)
+ A_proj[:r, :r] = L.diag_embed()
+ A_proj.addmm_(C, C.T, alpha=alpha)
+
+ L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+
+ # unproject and sort
+ Q_prime = B @ S
+ idx = torch.argsort(L_prime)
+ L_prime = L_prime[idx]
+ Q_prime = Q_prime[:, idx]
+
+ return L_prime, Q_prime
+
+
+ def eigh_plus_UVT_symmetrize(
+ Q: torch.Tensor,
+ L: torch.Tensor,
+ U: torch.Tensor,
+ V: torch.Tensor,
+ alpha: float,
+ retry_float64: bool = False,
+
+ ):
+ """
+ Q is ``(m, rank)``; L is ``(rank, )``; U and V are the low rank correction such that U V^T is ``(m, m)``.
+
+ This computes eigendecomposition of A, where
+
+ ``M = Q diag(L) Q^T + alpha * (U V^T)``;
+
+ ``A = (M + M^T) / 2``
+ """
+ m, rank = Q.shape
+ _, k = V.shape
+
+ # project U and V out of the Q subspace via Gram-schmidt
+ Q_T_U = Q.T @ U
+ U_perp = U - Q @ Q_T_U
+
+ Q_T_V = Q.T @ V
+ V_perp = V - Q @ Q_T_V
+
+ R = torch.hstack([U_perp, V_perp])
+ Q_perp, _ = torch_linalg.qr(R, retry_float64=retry_float64)
+
+ Q_B = torch.hstack([Q, Q_perp])
+ r_B = Q_B.shape[1]
+
+ # project, symmetrize and compute new eigendecomposition
+ A_proj = torch.zeros((r_B, r_B), device=Q.device, dtype=Q.dtype)
+ A_proj[:rank, :rank] = L.diag_embed()
+
+ Q_perp_T_U = Q_perp.T @ U
+ Q_B_T_U = torch.vstack([Q_T_U, Q_perp_T_U])
+
+ Q_perp_T_V = Q_perp.T @ V
+ Q_B_T_V = torch.vstack([Q_T_V, Q_perp_T_V])
+
+ update_proj = Q_B_T_U @ Q_B_T_V.T + Q_B_T_V @ Q_B_T_U.T
+ A_proj.add_(update_proj, alpha=alpha/2)
+
+ L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+
+ # unproject and sort
+ Q_prime = Q_B @ S
+
+ idx = torch.argsort(L_prime)
+ L_prime = L_prime[idx]
+ Q_prime = Q_prime[:, idx]
+
+ return L_prime, Q_prime
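For orientation, regularize_eigh post-processes an eigendecomposition such as the one produced by nystrom_approximation: it drops non-finite eigenvalues, applies a relative tolerance, truncates to the top eigenpairs and adds absolute or relative damping. A small sketch against an exact eigendecomposition (torch.linalg.eigh is used here only to keep the example self-contained):

    import torch
    from torchzero.linalg import regularize_eigh

    A = torch.randn(32, 32, dtype=torch.float64)
    A = A @ A.T  # symmetric PSD test matrix

    L, Q = torch.linalg.eigh(A)  # ascending eigenvalues, the order regularize_eigh assumes
    L, Q = regularize_eigh(L, Q, truncate=8, tol=1e-12, damping=1e-6)

    # at most the 8 largest eigenpairs remain, shifted by the damping term
    print(L.shape, Q.shape)

The eigh_plus_uuT, eigh_plus_UUT and eigh_plus_UVT_symmetrize helpers then update such a factorization with low-rank corrections without re-decomposing the full matrix.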
torchzero/linalg/orthogonalize.py CHANGED
@@ -1,4 +1,5 @@
  from typing import Literal
+
  import torch

  from ..utils.compile import allow_compile
@@ -49,9 +50,6 @@ def zeropower_via_newtonschulz5(G: torch.Tensor, coeffs=_NS_COEFFS) -> torch.Ten

  return X.to(G.dtype)

- # code from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
- # Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
- # Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
  def zeropower_via_svd(A: torch.Tensor) -> torch.Tensor:
  """
  Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
@@ -87,7 +85,7 @@ def orthogonalize_via_qr(A: torch.Tensor):
  return Q

  OrthogonalizeMethod = Literal["newtonschulz", "svd", "qr"]
- def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod = "newtonschulz") -> torch.Tensor:
+ def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod) -> torch.Tensor:
  if method == "newtonschulz": return zeropower_via_newtonschulz5(A)
  if method == "svd": return zeropower_via_svd(A)
  if method == "qr": return orthogonalize_via_qr(A)
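Since orthogonalize no longer has a default method, callers now pick one of "newtonschulz", "svd" or "qr" explicitly (the OrthogonalizeMethod literal above). A quick sketch for a plain 2D matrix:

    import torch
    from torchzero.linalg import orthogonalize

    A = torch.randn(256, 32)
    Q = orthogonalize(A, method="qr")  # method is now a required argument

    # for the QR path the columns should come out close to orthonormal
    print((Q.T @ Q - torch.eye(32)).abs().max())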
torchzero/linalg/qr.py CHANGED
@@ -2,6 +2,18 @@ from typing import Literal
  import torch
  from ..utils.compile import allow_compile

+
+ # super slow
+ # def cholesky_qr(A):
+ # """QR of (m, n) A via cholesky of (n, n) matrix"""
+ # AtA = A.T @ A
+
+ # L, _ = torch.linalg.cholesky_ex(AtA) # pylint:disable=not-callable
+ # R = L.T
+
+ # Q = torch.linalg.solve_triangular(R.T, A.T, upper=False).T # pylint:disable=not-callable
+ # return Q, R
+
  # reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
  @allow_compile
  def _get_w_tau(R: torch.Tensor, i: int, eps: float):
torchzero/linalg/solve.py CHANGED
@@ -25,15 +25,13 @@ def _make_A_mv_reg(A_mv: Callable, reg):

  def _identity(x): return x

- # TODO this is used in NystromSketchAndSolve
- # I need to add alternative to it where it just shifts eigenvalues by reg and uses their reciprocal
  def nystrom_sketch_and_solve(
  L: torch.Tensor,
  Q: torch.Tensor,
  b: torch.Tensor,
  reg: float = 1e-3,
  ) -> torch.Tensor:
- """Solves (Q diag(L) Q.T + reg*I)x = b. Becomes super unstable with reg smaller than like 1e-5.
+ """Solves ``(Q diag(L) Q.T + reg*I)x = b``. Becomes super unstable with reg smaller than like 1e-5.

  Args:
  L (torch.Tensor): eigenvalues, like from ``nystrom_approximation``
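As a self-contained illustration of the documented behavior (a full eigendecomposition from torch.linalg.eigh stands in for the low-rank Nyström factors L, Q that this function is normally paired with):

    import torch
    from torchzero.linalg import nystrom_sketch_and_solve

    A = torch.randn(64, 64, dtype=torch.float64)
    A = A @ A.T  # symmetric PSD matrix with eigendecomposition A = Q diag(L) Q^T
    L, Q = torch.linalg.eigh(A)

    b = torch.randn(64, dtype=torch.float64)
    reg = 1e-3
    x = nystrom_sketch_and_solve(L, Q, b, reg=reg)

    # residual of (A + reg*I) x = b; expected to be small when Q spans the full space
    print(torch.linalg.vector_norm((A + reg * torch.eye(64, dtype=A.dtype)) @ x - b))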
torchzero/linalg/svd.py CHANGED
@@ -1,20 +1,47 @@
- # import torch
-
- # # projected svd
- # # adapted from https://github.com/smortezavi/Randomized_SVD_GPU
- # def randomized_svd(M: torch.Tensor, k: int, driver=None):
- # *_, m, n = M.shape
- # transpose = False
- # if m < n:
- # transpose = True
- # M = M.mT
- # m,n = n,m
-
- # rand_matrix = torch.randn(size=(n, k), device=M.device, dtype=M.dtype)
- # Q, _ = torch.linalg.qr(M @ rand_matrix, mode='reduced') # pylint:disable=not-callable
- # smaller_matrix = Q.mT @ M
- # U_hat, s, V = torch.linalg.svd(smaller_matrix, driver=driver, full_matrices=False) # pylint:disable=not-callable
- # U = Q @ U_hat
-
- # if transpose: return V.mT, s, U.mT
- # return U, s, V
+ import torch
+
+ from . import torch_linalg
+
+
+ def tall_reduced_svd_via_eigh(A: torch.Tensor, tol: float = 0, retry_float64:bool=False):
+ """
+ Given a tall matrix A of size (m, n), computes U and S from the reduced SVD(A)
+ using the eigendecomposition of (n, n) matrix which is faster than direct SVD when m >= n.
+
+ This truncates small singular values that would causes nans,
+ so the returned U and S can have reduced dimension ``k <= n``.
+
+ Returns U of size ``(m, k)`` and S of size ``(k, )``.
+
+ Args:
+ A (torch.Tensor): A tall matrix of size (m, n) with m >= n.
+ tol (float): Tolerance for truncating small singular values. Singular values
+ less than ``tol * max_singular_value`` will be discarded.
+
+
+ """
+ # if m < n, A.T A will be low rank and we can't use eigh
+ m, n = A.size()
+ if m < n:
+ U, S, V = torch_linalg.svd(A, full_matrices=False, retry_float64=retry_float64)
+ return U, S
+
+ M = A.mH @ A # n,n
+
+ try:
+ L, Q = torch_linalg.eigh(M, retry_float64=retry_float64)
+ except torch.linalg.LinAlgError:
+ U, S, V = torch_linalg.svd(A, full_matrices=False, retry_float64=retry_float64)
+ return U, S
+
+ L = torch.flip(L, dims=[-1])
+ Q = torch.flip(Q, dims=[-1])
+
+ indices = L > tol * L[0] # L[0] is the max eigenvalue
+ L = L[indices]
+ Q = Q[:, indices]
+
+ S = L.sqrt()
+ U = (A @ Q) / S
+
+ return U, S
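A quick numerical check of the new helper (float64 and a well-conditioned random matrix are assumed, so nothing is truncated at the default tol=0; the import path is the module shown in this diff):

    import torch
    from torchzero.linalg.svd import tall_reduced_svd_via_eigh

    A = torch.randn(500, 32, dtype=torch.float64)
    U, S = tall_reduced_svd_via_eigh(A)

    # S should match the singular values of A, and U should have orthonormal columns
    print(torch.allclose(S, torch.linalg.svdvals(A)))
    print((U.T @ U - torch.eye(S.numel(), dtype=A.dtype)).abs().max())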
torchzero/modules/__init__.py CHANGED
@@ -1,4 +1,6 @@
  from . import experimental
+ from .adaptive import *
+ from .adaptive import lre_optimizers as lre
  from .clipping import *
  from .conjugate_gradient import *
  from .grad_approximation import *
@@ -7,9 +9,9 @@ from .line_search import *
  from .misc import *
  from .momentum import *
  from .ops import *
- from .adaptive import *
  from .projections import *
  from .quasi_newton import *
+ from .restarts import *
  from .second_order import *
  from .smoothing import *
  from .step_size import *
@@ -18,5 +20,4 @@ from .trust_region import *
  from .variance_reduction import *
  from .weight_decay import *
  from .wrappers import *
- from .restarts import *
- from .zeroth_order import *
+ from .zeroth_order import *
torchzero/modules/adaptive/__init__.py CHANGED
@@ -1,4 +1,5 @@
- from .adagrad import Adagrad, FullMatrixAdagrad, AdagradNorm
+ from . import lre_optimizers
+ from .adagrad import Adagrad, AdagradNorm, FullMatrixAdagrad

  # from .curveball import CurveBall
  # from .spectral import SpectralPreconditioner
@@ -8,14 +9,21 @@ from .adan import Adan
  from .adaptive_heavyball import AdaptiveHeavyBall
  from .aegd import AEGD
  from .esgd import ESGD
- from .lmadagrad import LMAdagrad
  from .lion import Lion
+ from .ggt import GGT
  from .mars import MARSCorrection
  from .matrix_momentum import MatrixMomentum
- from .msam import MSAMMomentum, MSAM
+ from .msam import MSAM, MSAMMomentum
  from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
  from .natural_gradient import NaturalGradient
  from .orthograd import OrthoGrad, orthograd_
+ from .psgd import (
+ PSGDDenseNewton,
+ PSGDKronNewton,
+ PSGDKronWhiten,
+ PSGDLRANewton,
+ PSGDLRAWhiten,
+ )
  from .rmsprop import RMSprop
  from .rprop import (
  BacktrackOnSignChange,
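For downstream code, the import surface of the adaptive subpackage changes in this release: LMAdagrad is removed (lmadagrad.py is deleted, see the file list) while GGT, the PSGD family and the lre_optimizers submodule are added. A sketch of the new imports (constructor arguments of these modules are not shown in this diff):

    from torchzero.modules.adaptive import (
        GGT,
        PSGDDenseNewton,
        PSGDKronNewton,
        PSGDKronWhiten,
        PSGDLRANewton,
        PSGDLRAWhiten,
        lre_optimizers,
    )

These should be chainable through torchzero.Optimizer in the same way as the existing modules.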