torchzero 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. torchzero/__init__.py +3 -1
  2. torchzero/_minimize/__init__.py +0 -0
  3. torchzero/_minimize/methods.py +95 -0
  4. torchzero/_minimize/minimize.py +518 -0
  5. torchzero/core/__init__.py +5 -5
  6. torchzero/core/chain.py +2 -1
  7. torchzero/core/functional.py +2 -1
  8. torchzero/core/module.py +75 -4
  9. torchzero/core/transform.py +6 -5
  10. torchzero/linalg/eigh.py +116 -68
  11. torchzero/linalg/linear_operator.py +1 -0
  12. torchzero/linalg/orthogonalize.py +60 -5
  13. torchzero/linalg/sketch.py +39 -0
  14. torchzero/modules/__init__.py +1 -0
  15. torchzero/modules/adaptive/adagrad.py +2 -0
  16. torchzero/modules/adaptive/adam.py +5 -1
  17. torchzero/modules/adaptive/adan.py +3 -0
  18. torchzero/modules/adaptive/ggt.py +20 -18
  19. torchzero/modules/adaptive/lion.py +3 -1
  20. torchzero/modules/adaptive/mars.py +6 -5
  21. torchzero/modules/adaptive/msam.py +3 -0
  22. torchzero/modules/adaptive/rmsprop.py +2 -0
  23. torchzero/modules/adaptive/rprop.py +9 -7
  24. torchzero/modules/adaptive/shampoo.py +9 -1
  25. torchzero/modules/adaptive/soap.py +32 -29
  26. torchzero/modules/basis/__init__.py +2 -0
  27. torchzero/modules/basis/ggt_basis.py +199 -0
  28. torchzero/modules/basis/soap_basis.py +254 -0
  29. torchzero/modules/clipping/ema_clipping.py +32 -27
  30. torchzero/modules/clipping/growth_clipping.py +1 -0
  31. torchzero/modules/experimental/__init__.py +1 -6
  32. torchzero/modules/experimental/coordinate_momentum.py +2 -0
  33. torchzero/modules/experimental/cubic_adam.py +4 -0
  34. torchzero/modules/grad_approximation/__init__.py +3 -2
  35. torchzero/modules/least_squares/gn.py +6 -0
  36. torchzero/modules/misc/gradient_accumulation.py +1 -0
  37. torchzero/modules/misc/misc.py +6 -0
  38. torchzero/modules/momentum/averaging.py +6 -0
  39. torchzero/modules/momentum/momentum.py +4 -0
  40. torchzero/modules/ops/__init__.py +0 -1
  41. torchzero/modules/ops/accumulate.py +4 -0
  42. torchzero/modules/ops/higher_level.py +6 -1
  43. torchzero/modules/second_order/inm.py +4 -0
  44. torchzero/modules/second_order/newton.py +11 -3
  45. torchzero/modules/second_order/newton_cg.py +7 -3
  46. torchzero/modules/second_order/nystrom.py +14 -19
  47. torchzero/modules/second_order/rsn.py +37 -6
  48. torchzero/modules/trust_region/trust_region.py +2 -1
  49. torchzero/utils/benchmarks/logistic.py +33 -18
  50. torchzero/utils/params.py +13 -1
  51. torchzero/utils/tensorlist.py +2 -2
  52. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/METADATA +1 -1
  53. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/RECORD +56 -53
  54. torchzero/modules/experimental/adanystrom.py +0 -258
  55. torchzero/modules/experimental/common_directions_whiten.py +0 -142
  56. torchzero/modules/experimental/eigen_sr1.py +0 -182
  57. torchzero/modules/experimental/eigengrad.py +0 -207
  58. /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
  59. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/WHEEL +0 -0
  60. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/top_level.txt +0 -0
torchzero/core/module.py CHANGED
@@ -2,18 +2,19 @@ import warnings
  from abc import ABC, abstractmethod
  from collections import ChainMap, defaultdict
  from collections.abc import Callable, Iterable, Sequence
- from typing import Any, overload, TYPE_CHECKING
+ from typing import Any, overload, TYPE_CHECKING, Literal

  import torch

  from ..linalg.linear_operator import LinearOperator
  from ..utils.optimizer import Init, ListLike, get_state_vals
- from ..utils.params import Params, _make_param_groups
+ from ..utils.params import Params, _make_param_groups, _set_fake_params_, _empty_fake_param_storage_
  from .functional import step_tensors

  if TYPE_CHECKING:
      from .objective import Objective

+ ProjectedBuffer = Literal["grad", "grad_sq", "grad_cu", "covariance", "inverse"]

  class Module(ABC):
      """Abstract base class for optimizer modules.
@@ -52,6 +53,12 @@ class Module(ABC):
          self._overridden_keys = set()
          """tracks keys overridden with ``set_param_groups``, only used to not give a warning"""

+         self._projected_keys: defaultdict[ProjectedBuffer, set[str]] = defaultdict(set)
+         """tracks keys of gradient-like buffers, covariance-like buffers, etc. for reprojecting"""
+
+         self._fake_params: dict[str, list[torch.Tensor]] = {}
+         """fake parameters used for state keys and shape inference; key is the name of a child, value is its list of fake parameters"""
+

      def set_param_groups(self, param_groups: Params):
          """Set custom parameter groups with per-parameter settings that this module will use."""
@@ -123,7 +130,9 @@ class Module(ABC):
              clone (bool):
                  If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
                  If ``key`` doesn't exist, ``tensors`` are always returned without cloning.
-             params (Iterable[torch.Tensor] | None, optional): pass None if ``tensors`` have different shape. Defaults to None.
+             params (Iterable[torch.Tensor] | None, optional):
+                 pass None if ``tensors`` have a different shape; fake params are then created from ``tensors``
+                 for state keys and shape inference. Defaults to None.
              grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
              loss (torch.Tensor | None, optional): loss. Defaults to None.
              closure (Callable | None, optional): closure. Defaults to None.
@@ -137,9 +146,26 @@ class Module(ABC):
              return tensors

          if clone: tensors = [t.clone() for t in tensors]
-         return step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
+
+         # set fake params to the same storage as tensors so as to not use any extra memory;
+         # they still refer to the same python objects, so they can be used
+         # as state keys and for shape inference when params aren't given.
+         fake = params is None
+         if fake:
+             if key not in self._fake_params:
+                 self._fake_params[key] = [torch.empty_like(t) for t in tensors]
+             params = self._fake_params[key]
+             _set_fake_params_(params, tensors)
+
+         update = step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
                                loss=loss, closure=closure, objective=objective)

+         # set fake params storage to empty
+         if fake:
+             _empty_fake_param_storage_(params)
+
+         return update
+

      def __repr__(self):
          s = self.__class__.__name__
@@ -322,6 +348,48 @@ class Module(ABC):
          self.global_state[key] = value
          return value

+     def get_child_projected_buffers(self, key: str, buff: ProjectedBuffer | Sequence[ProjectedBuffer], params: Sequence[torch.Tensor] | None = None) -> list[list[torch.Tensor]]:
+         """if params is None, assumes fake parameters"""
+         if isinstance(buff, str): buff = (buff, )
+
+         child = self.children[key]
+         child.on_get_projected_buffers()
+         if params is None:
+             params = self._fake_params[key]
+
+         vals = []
+         for b in buff:
+             for buff_key in child._projected_keys[b]:
+                 state = child.state[params[0]]
+                 if buff_key in state:
+                     tensors = [child.state[p][buff_key] for p in params]
+                     if isinstance(tensors[0], torch.Tensor):
+                         vals.append(tensors)
+                     else: # it's usually a deque
+                         assert isinstance(tensors[0], Sequence), type(tensors[0])
+                         vals.extend(zip(*tensors))
+
+                 elif buff_key in child.global_state:
+                     val = child.global_state[buff_key]
+                     if len(val) == 0: continue
+                     if isinstance(val[0], torch.Tensor):
+                         vals.append(val)
+                     else:
+                         assert isinstance(val[0], Sequence)
+                         vals.extend(zip(*val))
+
+         # recursively do this on children;
+         # note that if params are fake, children will have the same fake params,
+         # unless that child steps with something else - supporting that doesn't seem feasible.
+         for c in child.children:
+             vals.extend(child.get_child_projected_buffers(c, buff, params=params))
+
+         return vals
+
+     def add_projected_keys(self, buffer: ProjectedBuffer, *keys):
+         for k in keys: self._projected_keys[buffer].add(k)
+
+
      # ---------------------------- OVERRIDABLE METHODS --------------------------- #
      def update(self, objective:"Objective") -> None:
          """Updates internal state of this module. This should not modify ``objective.update``.
@@ -394,6 +462,9 @@
          """
          for c in self.children.values(): c.reset_for_online()

+     def on_get_projected_buffers(self):
+         """runs before projected buffers are accessed"""
+
      def _extra_pack(self) -> dict:
          """extra information to store in ``state_dict`` of this optimizer.
          Will be passed to ``_extra_unpack`` when loading the ``state_dict``."""
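Note on the fake-parameter mechanism above: `_set_fake_params_` and `_empty_fake_param_storage_` are imported from `torchzero/utils/params.py`, whose body is not shown in this diff. A minimal sketch of what such storage-sharing helpers could look like, assuming the behaviour described by the comments (illustrative only, not the packaged implementation):

```python
import torch

# Assumed behaviour: fake params borrow the storage of the incoming tensors for
# the duration of a child step (no extra memory), while remaining stable
# Python objects that can serve as per-parameter state keys.
def _set_fake_params_(fake_params: list[torch.Tensor], tensors: list[torch.Tensor]) -> None:
    for p, t in zip(fake_params, tensors):
        p.set_(t)  # p now aliases t's storage, shape and strides

def _empty_fake_param_storage_(fake_params: list[torch.Tensor]) -> None:
    for p in fake_params:
        p.set_(torch.empty(0, dtype=p.dtype, device=p.device))  # release the borrowed storage
```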
torchzero/core/transform.py CHANGED
@@ -1,12 +1,12 @@
  from abc import ABC, abstractmethod
  from collections.abc import Mapping, Sequence
  from operator import itemgetter
- from typing import Any, final, cast, TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any, cast, final

  import torch

+ from ..utils import safe_dict_update_, vec_to_tensors
  from .module import Module
- from ..utils import vec_to_tensors, safe_dict_update_

  if TYPE_CHECKING:
      from .chain import Chainable
@@ -31,7 +31,7 @@ class Transform(Module):

          self._objective = None
          if inner is not None:
-             self.set_child("inner", inner)
+             self.set_child("__inner", inner)

      # settings shouldn't mutate, so they are typed as Sequence[Mapping]
      def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
@@ -70,8 +70,8 @@ class Transform(Module):
      def apply(self, objective: "Objective"):

          # inner step
-         if "inner" in self.children:
-             inner = self.children["inner"]
+         if "__inner" in self.children:
+             inner = self.children["__inner"]
              objective = inner.step(objective)

          # apply and return
@@ -128,6 +128,7 @@ class TensorTransform(Transform):
          self._uses_grad = uses_grad
          self._uses_loss = uses_loss

+
      # ------------------------------- single tensor ------------------------------ #
      def single_tensor_initialize(
          self,
torchzero/linalg/eigh.py CHANGED
@@ -10,30 +10,18 @@ from .svd import tall_reduced_svd_via_eigh

  # https://arxiv.org/pdf/2110.02820
  def nystrom_approximation(
-     A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
-     A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
-     ndim: int,
-     rank: int,
-     device,
-     orthogonalize_method: OrthogonalizeMethod = 'qr',
+     Omega: torch.Tensor,
+     AOmega: torch.Tensor,
      eigv_tol: float = 0,
-     dtype = torch.float32,
-     generator = None,
  ) -> tuple[torch.Tensor, torch.Tensor]:
      """Computes Nyström approximation to positive-semidefinite A factored as Q L Q^T (truncated eigenvalue decomp),
      returns ``(L, Q)``.

      A is ``(m,m)``, then Q is ``(m, rank)``; L is a ``(rank, )`` vector - diagonal of ``(rank, rank)``"""
-     # basis
-     O = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
-     O = orthogonalize(O, method=orthogonalize_method) # Thin QR decomposition # pylint:disable=not-callable
-
-     # Y = AΩ
-     AO = mm(A_mv=A_mv, A_mm=A_mm, X=O)

-     v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(AO, ord='fro') # Compute shift # pylint:disable=not-callable
-     Yv = AO + v*O # Shift for stability
-     C = torch.linalg.cholesky_ex(O.mT @ Yv)[0] # pylint:disable=not-callable
+     v = torch.finfo(AOmega.dtype).eps * torch.linalg.matrix_norm(AOmega, ord='fro') # Compute shift # pylint:disable=not-callable
+     Yv = AOmega + v*Omega # Shift for stability
+     C = torch.linalg.cholesky_ex(Omega.mT @ Yv)[0] # pylint:disable=not-callable
      B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable

      # Q, S, _ = torch_linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
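With this refactor `nystrom_approximation` no longer builds its own test matrix: the caller passes the test matrix `Omega` and the sketch `AOmega = A @ Omega`. A hedged usage sketch (the PSD matrix `A`, its size and the rank are illustrative; import paths follow the files shown in this diff):

```python
import torch
from torchzero.linalg.eigh import nystrom_approximation
from torchzero.linalg.orthogonalize import orthogonalize

A = torch.randn(100, 100)
A = A @ A.T                                               # any positive semi-definite matrix
Omega = orthogonalize(torch.randn(100, 10), method="qr")  # orthonormal Gaussian test matrix
AOmega = A @ Omega                                        # sketch Y = A @ Omega
L, Q = nystrom_approximation(Omega, AOmega)               # A ≈ Q @ diag(L) @ Q.T
```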
@@ -138,26 +126,35 @@ def eigh_plus_uuT(
      return L_prime, Q_prime


- def eigh_plus_UUT(
+ def eigh_plus_UUt(
      L: torch.Tensor,
      Q: torch.Tensor,
      U: torch.Tensor,
-     alpha: float = 1,
+     alpha: float | torch.Tensor = 1,
      tol = None,
-     retry_float64: bool = False,
- ):
+     ortho_method: OrthogonalizeMethod = 'qr',
+     retry_float64=True,
+ ) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
      """
      compute eigendecomposition of Q L Q^T + alpha * (U U^T), where Q is ``(m, rank)`` and L is ``(rank, )``,
      U is ``(m, k)`` where k is rank of correction
+
+     returns ``(L, Q)``
      """
      if U.size(1) == 1:
-         return eigh_plus_uuT(L, Q, U[:,0], alpha=alpha, tol=tol, retry_float64=retry_float64)
+         return eigh_plus_uuT(L, Q, U[:,0], alpha=float(alpha), tol=tol)
+
+     # make alpha shape (k, )
+     k = U.size(1)
+     if isinstance(alpha, torch.Tensor):
+         alpha = torch.broadcast_to(alpha, (k, ))
+     else:
+         alpha = torch.full((k,), float(alpha), device=U.device, dtype=U.dtype)

      if tol is None: tol = torch.finfo(Q.dtype).eps
      m, r = Q.shape

-
-     Z = Q.T @ U # (r, k)
-     U_res = U - Q @ Z # (m, k)
+     QtU = Q.T @ U # (r, k)
+     U_res = U - Q @ QtU # (m, k)

      # find cols of U not in col space of Q
      res_norms = torch.linalg.vector_norm(U_res, dim=0) # pylint:disable=not-callable
@@ -167,23 +164,26 @@ def eigh_plus_UUT(
      if k_prime == 0:
          # all cols are in Q
          B = Q
-         C = Z # (r x k)
+         C = QtU # (r x k)
          r_new = r
      else:
          # orthonormalize directions that aren't in Q
          U_new = U_res[:, new_indices]
-         Q_u, _ = torch_linalg.qr(U_new, mode='reduced', retry_float64=retry_float64)
+         Q_u = orthogonalize(U_new, method=ortho_method)
          B = torch.hstack([Q, Q_u])
-         C = torch.vstack([Z, Q_u.T @ U])
+         C = torch.vstack([QtU, Q_u.T @ U_res])
          r_new = r + k_prime

-
      # project and compute new eigendecomposition
      A_proj = torch.zeros((r_new, r_new), device=Q.device, dtype=Q.dtype)
      A_proj[:r, :r] = L.diag_embed()
-     A_proj.addmm_(C, C.T, alpha=alpha)
+     # A_proj += (C @ C.T).mul_(alpha)
+     A_proj.addmm_(C * alpha, C.T)

-     L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+     try:
+         L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+     except torch.linalg.LinAlgError:
+         return None, None

      # unproject and sort
      Q_prime = B @ S
@@ -194,60 +194,108 @@
      return L_prime, Q_prime


- def eigh_plus_UVT_symmetrize(
-     Q: torch.Tensor,
+ def eigh_plus_UUt_mm(
+     # A1 = Q @ diag(L) @ Q.T
      L: torch.Tensor,
+     Q: torch.Tensor,
+
+     # A2 = U @ U.T
      U: torch.Tensor,
-     V: torch.Tensor,
-     alpha: float,
-     retry_float64: bool = False,

- ):
+     # rhs
+     B: torch.Tensor,
+
+     # weights
+     w1: float,
+     w2: float | torch.Tensor,
+
+ ) -> torch.Tensor:
      """
-     Q is ``(m, rank)``; L is ``(rank, )``; U and V are the low rank correction such that U V^T is ``(m, m)``.
+     Computes ``(w1 * (Q L Q^T) + U diag(w2) U^T) @ B``,

-     This computes eigendecomposition of A, where
+     Q is ``(m, rank)``, L is ``(rank, )``, U is ``(m, z)``, B is ``(m, k)``.

-     ``M = Q diag(L) Q^T + alpha * (U V^T)``;
+     Returns ``(m, k)``
+     """
+     # sketch Q L Q^T
+     QtB = Q.T @ B # (rank, k)
+     LQtB = L.unsqueeze(1) * QtB # (rank, k)
+     sketch1 = Q @ LQtB # (m, k)
+
+     # sketch U U^T
+     UtB = U.T @ B # (z, k)
+     if isinstance(w2, torch.Tensor) and w2.numel() > 1: w2UtB = w2.unsqueeze(-1) * UtB
+     else: w2UtB = w2 * UtB
+     sketch2 = U @ w2UtB # (m, k)
+
+     return w1 * sketch1 + sketch2

-     ``A = (M + M^T) / 2``
+
+ def randomized_eigh_plus_UUt(
+     L1: torch.Tensor,
+     Q1: torch.Tensor,
+     U: torch.Tensor,
+     w1: float,
+     w2: float | torch.Tensor,
+     oversampling_p: int,
+     rank: int,
+     eig_tol: float,
+     damping: float,
+     rdamping: float,
+     ortho_method: OrthogonalizeMethod = 'qr',
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+     """
+     compute randomized eigendecomposition of w1 * Q L Q^T + w2 * (U U^T),
+     where Q is ``(m, rank)`` and L is ``(rank, )``,
+     U is ``(m, k)`` where k is rank of correction, returns ``(L, Q)``
      """
-     m, rank = Q.shape
-     _, k = V.shape
+     n = Q1.shape[0]
+     device = Q1.device
+     dtype = Q1.dtype
+     l = rank + oversampling_p

-     # project U and V out of the Q subspace via Gram-schmidt
-     Q_T_U = Q.T @ U
-     U_perp = U - Q @ Q_T_U
+     # gaussian test matrix
+     Omega = torch.randn(n, l, device=device, dtype=dtype)

-     Q_T_V = Q.T @ V
-     V_perp = V - Q @ Q_T_V
+     # sketch
+     AOmega = eigh_plus_UUt_mm(L1, Q1, U, Omega, w1, w2)
+     Q = orthogonalize(AOmega, ortho_method)

-     R = torch.hstack([U_perp, V_perp])
-     Q_perp, _ = torch_linalg.qr(R, retry_float64=retry_float64)
+     AQ = eigh_plus_UUt_mm(L1, Q1, U, Q, w1, w2)
+     QtAQ = Q.T @ AQ

-     Q_B = torch.hstack([Q, Q_perp])
-     r_B = Q_B.shape[1]
+     W = (QtAQ + QtAQ.T) / 2.0

-     # project, symmetrize and compute new eigendecomposition
-     A_proj = torch.zeros((r_B, r_B), device=Q.device, dtype=Q.dtype)
-     A_proj[:rank, :rank] = L.diag_embed()
+     # compute new L and Q
+     try:
+         L_prime, S = torch.linalg.eigh(W) # pylint:disable=not-callable
+     except torch.linalg.LinAlgError:
+         return L1, Q1

-     Q_perp_T_U = Q_perp.T @ U
-     Q_B_T_U = torch.vstack([Q_T_U, Q_perp_T_U])
+     L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)

-     Q_perp_T_V = Q_perp.T @ V
-     Q_B_T_V = torch.vstack([Q_T_V, Q_perp_T_V])
+     if L_prime is None or S is None:
+         return L1, Q1

-     update_proj = Q_B_T_U @ Q_B_T_V.T + Q_B_T_V @ Q_B_T_U.T
-     A_proj.add_(update_proj, alpha=alpha/2)
+     return L_prime, Q @ S

-     L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)

-     # unproject and sort
-     Q_prime = Q_B @ S
+ def rank1_eigh(v: torch.Tensor):
+     """returns ``(L, Q)`` of ``(v v^T)``"""
+     vv = v.dot(v)
+     norm = vv.sqrt().clip(min=torch.finfo(vv.dtype).tiny * 2)

-     idx = torch.argsort(L_prime)
-     L_prime = L_prime[idx]
-     Q_prime = Q_prime[:, idx]
+     L = vv.unsqueeze(0) # (rank, )
+     Q = v.unsqueeze(-1) / norm # (m, rank)

-     return L_prime, Q_prime
+     return L, Q
+
+ def low_rank_eigh(U: torch.Tensor):
+     """returns ``(L, Q)`` of ``alpha * (U U^T)`` (from GGT)"""
+     M = U.T @ U
+     L, S = torch.linalg.eigh(M) # pylint:disable=not-callable
+
+     Q = U @ S
+     Q /= torch.sqrt(L).clip(min=torch.finfo(L.dtype).tiny * 2)
+
+     return L, Q
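The identity behind the new `low_rank_eigh` helper: if `U.T @ U = S @ diag(L) @ S.T`, then `Q = U @ S / sqrt(L)` has orthonormal columns and `U @ U.T = Q @ diag(L) @ Q.T`. A small self-contained check in plain PyTorch:

```python
import torch

U = torch.randn(500, 8, dtype=torch.float64)
L, S = torch.linalg.eigh(U.T @ U)     # eigendecomposition of the small (8, 8) Gram matrix
Q = (U @ S) / L.sqrt()                # scale each column by 1 / sqrt(eigenvalue)

assert torch.allclose(Q @ torch.diag(L) @ Q.T, U @ U.T, atol=1e-8)            # same outer product
assert torch.allclose(Q.T @ Q, torch.eye(8, dtype=torch.float64), atol=1e-8)  # orthonormal columns
```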
@@ -425,3 +425,4 @@ class Eigendecomposition(LinearOperator)
      def size(self):
          n = self.Q.size(0)
          return (n,n)
+
torchzero/linalg/orthogonalize.py CHANGED
@@ -51,9 +51,6 @@ def zeropower_via_newtonschulz5(G: torch.Tensor, coeffs=_NS_COEFFS) -> torch.Ten
      return X.to(G.dtype)

  def zeropower_via_svd(A: torch.Tensor) -> torch.Tensor:
-     """
-     Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
-     """
      try:
          U, S, Vt = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
      except torch.linalg.LinAlgError:
@@ -84,9 +81,67 @@ def orthogonalize_via_qr(A: torch.Tensor):

      return Q

- OrthogonalizeMethod = Literal["newtonschulz", "svd", "qr"]
+ # CODE FROM https://github.com/HomebrewML/HeavyBall/blob/main/heavyball/utils.py:
+
+ ## Based on https://arxiv.org/pdf/2505.16932v3
+ # and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
+ # and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
+ #
+ # under the MIT License
+ # Coefficients are from https://arxiv.org/pdf/2505.16932v3
+ ABC_LIST: list[tuple[float, float, float]] = [
+     (8.28721201814563, -23.595886519098837, 17.300387312530933),
+     (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
+     (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
+     (3.3184196573706015, -2.488488024314874, 0.51004894012372),
+     (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
+     (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
+     (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
+     (1.875, -1.25, 0.375),
+ ]
+
+ # safety factor for numerical stability (but exclude last polynomial)
+ ABC_LIST_STABLE: list[tuple[float, float, float]] = [
+     (a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1]
+ ] + [ABC_LIST[-1]]
+
+
+ def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
+     """
+     Polar Express algorithm for the matrix sign function:
+     https://arxiv.org/abs/2505.16932
+     """
+     assert G.ndim >= 2
+     should_transpose: bool = G.size(-2) > G.size(-1)
+
+     x = G
+     if should_transpose:
+         x = x.mT
+
+     x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
+
+     for step in range(steps):
+         a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
+         s = x @ x.mT
+         # goal is to compute x = a x + b S x + c S^2 x
+         # we can break this up into: x = (a I + (b I + c S) S) x
+         y = c * s
+         y.diagonal(dim1=-2, dim2=-1).add_(b)
+         y = y @ s
+         y.diagonal(dim1=-2, dim2=-1).add_(a)
+         x = y @ x
+
+     if should_transpose:
+         x = x.mT
+     return x.float()
+
+
+ ###### END
+
+ OrthogonalizeMethod = Literal["newtonschulz", "ns5", "polar_express", "svd", "qr", "eigh"]
  def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod) -> torch.Tensor:
-     if method == "newtonschulz": return zeropower_via_newtonschulz5(A)
+     if method in ("newtonschulz", "ns5"): return zeropower_via_newtonschulz5(A)
+     if method == "polar_express": return msign(A)
      if method == "svd": return zeropower_via_svd(A)
      if method == "qr": return orthogonalize_via_qr(A)
      if method == "eigh": return zeropower_via_eigh(A)
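A hedged usage sketch of the new `"polar_express"` option (matrix sizes are illustrative): it routes through `msign`, which approximates the semi-orthogonal polar factor of a full-rank matrix, so `Q.T @ Q` should be close to the identity for a tall input:

```python
import torch
from torchzero.linalg.orthogonalize import orthogonalize

G = torch.randn(256, 128)
Q = orthogonalize(G, method="polar_express")        # ≈ U @ V.T from the SVD of G
print(torch.linalg.norm(Q.T @ Q - torch.eye(128)))  # small: columns approximately orthonormal
```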
torchzero/linalg/sketch.py ADDED
@@ -0,0 +1,39 @@
+ import math
+
+ import torch
+
+ from .orthogonalize import orthogonalize_via_qr
+ from .linear_operator import LinearOperator, Dense
+
+
+ class Permutation(LinearOperator):
+     def __init__(self, indices:torch.Tensor):
+         self.indices = indices
+         self.device = indices.device
+
+     def matvec(self, x):
+         return x[self.indices]
+
+     def matmat(self, X):
+         return Dense(X[:, self.indices])
+
+ def orthonormal_sketch(m, k, dtype, device, generator):
+     return orthogonalize_via_qr(torch.randn(m, k, dtype=dtype, device=device, generator=generator))
+
+ def rademacher_sketch(m, k, dtype, device, generator):
+     rademacher = torch.bernoulli(torch.full((m, k), 0.5, device=device, dtype=dtype), generator=generator).mul_(2).sub_(1)
+     return rademacher.mul_(1 / math.sqrt(m))
+
+ def row_sketch(m, k, dtype, device, generator):
+     weights = torch.ones(m, dtype=dtype, device=device)
+     indices = torch.multinomial(weights, k, replacement=False, generator=generator)
+
+     P = torch.zeros(m, k, dtype=dtype, device=device)
+     P[indices, range(k)] = 1
+     return P
+
+ def topk_rows_sketch(v: torch.Tensor, m, k, dtype, device):
+     _, indices = torch.topk(v, k)
+     P = torch.zeros(m, k, dtype=dtype, device=device)
+     P[indices, range(k)] = 1
+     return P
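A hedged usage sketch of the new sketching helpers (signatures as defined in the new file above; the matrix `A` and the sizes are illustrative):

```python
import torch
from torchzero.linalg.sketch import orthonormal_sketch, rademacher_sketch

gen = torch.Generator().manual_seed(0)
m, k = 1000, 32
S1 = rademacher_sketch(m, k, dtype=torch.float32, device="cpu", generator=gen)   # entries ±1/sqrt(m)
S2 = orthonormal_sketch(m, k, dtype=torch.float32, device="cpu", generator=gen)  # QR of a Gaussian matrix

A = torch.randn(200, m)
A_small = A @ S1    # (200, k) randomized compression of A's column space
```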
torchzero/modules/__init__.py CHANGED
@@ -21,3 +21,4 @@ from .variance_reduction import *
  from .weight_decay import *
  from .wrappers import *
  from .zeroth_order import *
+ from .basis import *
torchzero/modules/adaptive/adagrad.py CHANGED
@@ -40,6 +40,7 @@ class Adagrad(TensorTransform):
          super().__init__(defaults=defaults, inner=inner)

          self.set_child('accumulator', accumulator_tfm)
+         self.add_projected_keys("grad", "accumulator")

      @torch.no_grad
      def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
@@ -235,6 +236,7 @@ class FullMatrixAdagrad(TensorTransform):
          super().__init__(defaults=defaults, inner=inner, concat_params=concat_params)

          self.set_child("accumulator", accumulator_tfm)
+         self.add_projected_keys("covariance", "accumulator")

      @torch.no_grad
      def single_tensor_update(self, tensor, param, grad, loss, state, setting):
torchzero/modules/adaptive/adam.py CHANGED
@@ -38,6 +38,9 @@ class Adam(TensorTransform):
          self.set_child('exp_avg', exp_avg_tfm)
          self.set_child('exp_avg_sq', exp_avg_sq_tfm)

+         self.add_projected_keys("grad", "exp_avg")
+         self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
+
      @torch.no_grad
      def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
          self.increment_counter("step", start=0)
@@ -81,4 +84,5 @@ class Adam(TensorTransform):
          exp_avg = exp_avg * alpha

          # ---------------------------------- update ---------------------------------- #
-         return exp_avg / exp_avg_sq.sqrt().add_(eps)
+         return exp_avg / exp_avg_sq.sqrt().add_(eps)
+
torchzero/modules/adaptive/adan.py CHANGED
@@ -87,6 +87,9 @@ class Adan(TensorTransform):
          self.set_child("v", v_tfm)
          self.set_child("n", n_tfm)

+         self.add_projected_keys("grad_sq", "m", "v", "g_prev")
+         self.add_projected_keys("grad", "n")
+
      @torch.no_grad
      def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
          tensors = TensorList(tensors)
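The `add_projected_keys` calls added to Adagrad, Adam and Adan above tag which state buffers behave like gradients, squared gradients or covariances, so that the new basis modules (`ggt_basis.py`, `soap_basis.py`) can collect and reproject them via `get_child_projected_buffers`. A hedged sketch of the declaration pattern in a custom `TensorTransform` (the update rule and state handling below are illustrative, not taken from the package):

```python
import torch
from torchzero.core.transform import TensorTransform

class EmaWhiten(TensorTransform):
    """Toy transform with one gradient-like and one squared-gradient-like buffer."""

    def __init__(self, beta: float = 0.9):
        super().__init__(defaults=dict(beta=beta))
        self.add_projected_keys("grad", "exp_avg")        # rotates like a gradient
        self.add_projected_keys("grad_sq", "exp_avg_sq")  # rotates like a squared gradient

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        beta = setting["beta"]
        exp_avg = state.setdefault("exp_avg", torch.zeros_like(tensor))
        exp_avg_sq = state.setdefault("exp_avg_sq", torch.zeros_like(tensor))
        exp_avg.lerp_(tensor, 1 - beta)
        exp_avg_sq.mul_(beta).addcmul_(tensor, tensor, value=1 - beta)
        return exp_avg / exp_avg_sq.sqrt().add_(1e-8)
```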