torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,116 @@
+# pylint:disable=not-callable
+"""all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
+import math
+import warnings
+
+import torch
+
+from ....core import Chainable, TensorTransform
+from ._psgd_utils import _initialize_lra_state_
+from .psgd import lift2single, precond_grad_lra, update_precond_lra_whiten
+
+# matches
+class PSGDLRAWhiten(TensorTransform):
+    """Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
+
+    Args:
+        rank (int, optional):
+            Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
+        init_scale (float | None, optional):
+            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
+        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
+        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
+        damping (float, optional):
+            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
+        grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
+        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
+        concat_params (bool, optional):
+            if True, treats all parameters as concatenated to a single vector.
+            If False, each parameter is preconditioned separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.
+
+    ###Examples:
+
+    Pure PSGD LRA:
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.LRAWhiten(),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    Momentum into preconditioner (whitens momentum):
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.EMA(0.9),
+        tz.m.LRAWhiten(),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    Updating the preconditioner from gradients and applying it to momentum:
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.LRAWhiten(inner=tz.m.EMA(0.9)),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    """
+    def __init__(
+        self,
+        rank: int = 10,
+        init_scale: float | None = None,
+        lr_preconditioner=0.1,
+        betaL=0.9,
+        damping=1e-9,
+        grad_clip_max_amp=float("inf"),
+        update_probability=1.0,
+
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults["inner"], defaults["self"]
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        _initialize_lra_state_(tensor, state, setting)
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+
+        g = tensor.ravel().unsqueeze(1) # column vector
+
+        UVd = state["UVd"]
+        if UVd[2] is None: # initialize d on the fly
+            UVd[2] = (torch.mean(g**4) + setting["damping"]**4)**(-1/8) * torch.ones_like(g)
+
+        if torch.rand([]) < setting["update_probability"]: # update preconditioner
+            update_precond_lra_whiten(
+                UVd=UVd,
+                Luvd=state["Luvd"],
+                g=g,
+                lr=setting["lr_preconditioner"],
+                betaL=setting["betaL"],
+                damping=setting["damping"],
+            )
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+
+        g = tensor.ravel().unsqueeze(1)
+        pre_grad = precond_grad_lra(UVd=state["UVd"], g=g)
+
+        # norm clipping
+        grad_clip_max_amp = setting["grad_clip_max_amp"]
+        if grad_clip_max_amp < float("inf"): # clip preconditioned gradient
+            amp = torch.sqrt(torch.mean(pre_grad * pre_grad))
+            if amp > grad_clip_max_amp:
+                pre_grad *= grad_clip_max_amp/amp
+
+        return pre_grad.view_as(tensor)
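
Note that the clipping in `single_tensor_apply` above is by RMS amplitude rather than by L2 norm: the preconditioned update is rescaled only when `sqrt(mean(update**2))` exceeds `grad_clip_max_amp`, and since the RMS equals the L2 norm divided by the square root of the element count, the threshold does not grow with the number of parameters. A minimal standalone sketch of that rule; `clip_by_rms_amplitude` is an illustrative helper, not part of the torchzero API:

```python
import torch

def clip_by_rms_amplitude(update: torch.Tensor, max_amp: float) -> torch.Tensor:
    """Rescale `update` so its root-mean-square amplitude does not exceed `max_amp`."""
    if max_amp == float("inf"):
        return update                                  # clipping disabled, as in the default
    amp = torch.sqrt(torch.mean(update * update))      # RMS amplitude, i.e. norm / sqrt(numel)
    if amp > max_amp:
        update = update * (max_amp / amp)
    return update

u = torch.randn(1000) * 5
print(clip_by_rms_amplitude(u, max_amp=1.0).square().mean().sqrt())  # <= 1.0
```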
@@ -304,7 +304,7 @@ class SignConsistencyMask(TensorTransform):
     GD that skips update for weights where gradient sign changed compared to previous gradient.
 
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Mul(tz.m.SignConsistencyMask()),
         tz.m.LR(1e-2)
@@ -334,7 +334,7 @@ class SignConsistencyLRs(TensorTransform):
 
     ```python
 
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Mul(tz.m.SignConsistencyLRs()),
         tz.m.LR(1e-2)
@@ -31,7 +31,7 @@ class SAM(Transform):
     SAM-SGD:
 
    ```py
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SAM(),
         tz.m.LR(1e-2)
@@ -41,7 +41,7 @@ class SAM(Transform):
     SAM-Adam:
 
    ```
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SAM(),
         tz.m.Adam(),
@@ -149,7 +149,7 @@ class ASAM(SAM):
     ASAM-SGD:
 
    ```py
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.ASAM(),
         tz.m.LR(1e-2)
@@ -159,7 +159,7 @@ class ASAM(SAM):
     ASAM-Adam:
 
    ```
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.ASAM(),
         tz.m.Adam(),
@@ -1,4 +1,4 @@
-from collections.abc import Sequence
+from collections.abc import Sequence, Iterable
 
 import numpy as np
 import torch
@@ -82,6 +82,31 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     tensor = tensor.unflatten(0, flat_sizes)
     return tensor.permute(*np.argsort(sort_idxs).tolist())
 
+def diagonal_memory(params: torch.nn.Module | torch.Tensor | Iterable[torch.Tensor]):
+    """computes number of parameters"""
+    if isinstance(params, torch.nn.Module): params = params.parameters()
+    if isinstance(params, torch.Tensor): params = [params,]
+    params = list(params)
+    return sum(p.numel() for p in params)
+
+def kronecker_memory(params: torch.nn.Module | torch.Tensor | Iterable[torch.Tensor], merge_small:bool=True, max_dim:int=10_000):
+    """computes total size of tensors required to store shampoo preconditioner"""
+    if isinstance(params, torch.nn.Module): params = params.parameters()
+    if isinstance(params, torch.Tensor): params = [params,]
+    params = list(params)
+
+    memory = 0
+    for p in params:
+        if merge_small:
+            p, _, _ = _merge_small_dims(p, max_dim)
+        for dim in p.size():
+            if dim > max_dim: memory += dim
+            else: memory += dim**2
+
+    return memory
+
+
+
 
 class Shampoo(TensorTransform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).
@@ -112,7 +137,7 @@ class Shampoo(TensorTransform):
     Shampoo grafted to Adam
 
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.GraftModules(
             direction = tz.m.Shampoo(),
@@ -125,7 +150,7 @@ class Shampoo(TensorTransform):
     Adam with Shampoo preconditioner
 
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
         tz.m.Debias(0.9, 0.999),
@@ -132,7 +132,7 @@ class SOAP(TensorTransform):
     SOAP:
 
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SOAP(),
         tz.m.LR(1e-3)
@@ -141,7 +141,7 @@ class SOAP(TensorTransform):
     Stabilized SOAP:
 
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SOAP(),
         tz.m.NormalizeByEMA(max_ema_growth=1.2),
@@ -156,7 +156,7 @@ class SOAP(TensorTransform):
         shampoo_beta: float | None = 0.95,
         precond_freq: int = 10,
         merge_small: bool = True,
-        max_dim: int = 10_000,
+        max_dim: int = 4096,
         precondition_1d: bool = True,
         eps: float = 1e-8,
         debias: bool = True,
@@ -50,7 +50,7 @@ class SophiaH(Transform):
 
    ```python
 
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SophiaH(),
         tz.m.LR(0.1)
@@ -63,7 +63,7 @@ class SophiaH(Transform):
 
    ```python
 
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
         tz.m.LR(0.1)
@@ -161,7 +161,7 @@ class ClipValue(TensorTransform):
 
     Gradient clipping:
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.ClipValue(1),
         tz.m.Adam(),
@@ -171,7 +171,7 @@ class ClipValue(TensorTransform):
 
     Update clipping:
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.ClipValue(1),
@@ -211,7 +211,7 @@ class ClipNorm(TensorTransform):
 
     Gradient norm clipping:
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.ClipNorm(1),
         tz.m.Adam(),
@@ -221,7 +221,7 @@ class ClipNorm(TensorTransform):
 
     Update norm clipping:
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.ClipNorm(1),
@@ -277,7 +277,7 @@ class Normalize(TensorTransform):
     Examples:
     Gradient normalization:
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Normalize(1),
         tz.m.Adam(),
@@ -288,7 +288,7 @@ class Normalize(TensorTransform):
     Update normalization:
 
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.Normalize(1),
@@ -378,7 +378,7 @@ class Centralize(TensorTransform):
 
     Standard gradient centralization:
    ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Centralize(dim=0),
         tz.m.LR(1e-2),
@@ -7,7 +7,7 @@ from ...core import Chainable, TensorTransform
 
 from ...utils import TensorList, safe_dict_update_, unpack_dicts, unpack_states
 from ..quasi_newton.quasi_newton import HessianUpdateStrategy
-from ..functional import safe_clip
+from ..opt_utils import safe_clip
 
 
 class ConguateGradientBase(TensorTransform, ABC):
@@ -68,7 +68,7 @@ class ConguateGradientBase(TensorTransform, ABC):
         self.increment_counter("step", start=0)
 
         # initialize on first step
-        if self.global_state.get('stage', "first step") == "first update":
+        if self.global_state.get('stage', "first update") == "first update":
             g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
             d_prev.copy_(tensors)
             g_prev.copy_(tensors)
@@ -1,8 +1,13 @@
 """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
+from .adanystrom import AdaNystrom
+from .common_directions_whiten import CommonDirectionsWhiten
 from .coordinate_momentum import CoordinateMomentum
+from .cubic_adam import CubicAdam, SubspaceCubicAdam
 from .curveball import CurveBall
+from .eigen_sr1 import EigenSR1
 
 # from dct import DCTProjection
+from .eigengrad import Eigengrad
 from .fft import FFTProjection
 from .gradmin import GradMin
 from .higher_order_newton import HigherOrderNewton
@@ -0,0 +1,258 @@
+# pylint: disable = non-ascii-name
+import torch
+
+from ...core import Chainable, TensorTransform
+from ...linalg import (
+    OrthogonalizeMethod,
+    orthogonalize,
+    regularize_eigh,
+    torch_linalg,
+)
+from ...linalg.linear_operator import Eigendecomposition
+from ..adaptive.lre_optimizers import LREOptimizerBase
+from .eigengrad import _eigengrad_update_state_, eigengrad_apply
+
+
+def weighted_eigen_plus_rank1_mm(
+    # A1 = Q1 @ diag(L1) @ Q1.T
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+
+    # K2 = v2 @ v2.T
+    v2: torch.Tensor,
+
+    # second matrix
+    B: torch.Tensor,
+
+    # weights
+    w1: float,
+    w2: float,
+
+) -> torch.Tensor:
+    """
+    Computes ``(w1 * A1 + w2 * A2) @ B``, where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
+
+    Returns ``(n, k)``
+
+    Args:
+        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
+        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
+        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)``.
+        B (torch.Tensor): shape ``(n, k)``.
+        w1 (float): weight for A1.
+        w2 (float): weight for A2.
+
+    """
+    # sketch A1
+    QTB = Q1.T @ B # (rank, k)
+    LQTB = L1.unsqueeze(1) * QTB # (rank, k)
+    sketch1 = Q1 @ LQTB # (n, k)
+
+    # skecth A2
+    vB = v2 @ B
+    sketch2 = v2.outer(vB)
+
+    return w1 * sketch1 + w2 * sketch2
+
+
+def adanystrom_update(
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+    v2: torch.Tensor,
+    w1: float,
+    w2: float,
+    oversampling_p: int,
+    rank: int,
+    eig_tol: float,
+    damping: float,
+    rdamping: float,
+    orthogonalize_method: OrthogonalizeMethod,
+
+) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    """computes the Nyström approximation of ``(w1 * A1 + w2 * A2)``,
+    where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
+
+    returns L of shape ``(k, )`` and Q of shape ``(n, k)``.
+
+    Args:
+        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
+        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
+        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)`` or ``(n, 1)``.
+        w1 (float): weight for A1.
+        w2 (float): weight for A2.
+    """
+    n = Q1.shape[0]
+    device = Q1.device
+    dtype = Q1.dtype
+    l = rank + oversampling_p
+
+    # gaussian test matrix
+    Omega = torch.randn(n, l, device=device, dtype=dtype)
+
+    # sketch
+    AOmega = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Omega, w1, w2)
+    Q = orthogonalize(AOmega, orthogonalize_method)
+
+    AQ = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Q, w1, w2)
+    QTAQ = Q.T @ AQ
+
+    W = (QTAQ + QTAQ.T) / 2.0
+
+    # compute new L and Q
+    try:
+        L_prime, S = torch_linalg.eigh(W, retry_float64=True)
+    except torch.linalg.LinAlgError:
+        return L1, Q1
+
+    L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)
+
+    if L_prime is None or S is None:
+        return L1, Q1
+
+    return L_prime, Q @ S
+
+
+# def adanystrom_update2(
+#     L1: torch.Tensor,
+#     Q1: torch.Tensor,
+#     v2: torch.Tensor,
+#     w1: float,
+#     w2: float,
+#     rank: int,
+# ):
+#     def A_mm(X):
+#         return weighted_eigen_plus_rank1_mm(L1=L1, Q1=Q1, v2=v2, B=X, w1=w1, w2=w2)
+
+#     return nystrom_approximation(A_mm, A_mm=A_mm, ndim=v2.numel(), rank=rank, device=L1.device, dtype=L1.dtype)
+
+class AdaNystrom(TensorTransform):
+    """Adagrad/RMSprop/Adam with Nyström-approximated covariance matrix.
+
+    Args:
+        rank (_type_): rank of Nyström approximation.
+        w1 (float, optional): weight of current covariance matrix. Defaults to 0.95.
+        w2 (float, optional): weight of new gradient in covariance matrix. Defaults to 0.05.
+        oversampling (int, optional): number of extra random vectors (top rank eigenvalues are kept). Defaults to 10.
+        eig_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when updating the preconditioner. Defaults to 1e-7.
+        damping (float, optional):
+            added to eigenvalues when updating the preconditioner. Defaults to 1e-8.
+        rdamping (float, optional):
+            added to eigenvalues when updating the preconditioner, relative to largest eigenvalue. Defaults to 0.
+        mm_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when computing the update. Defaults to 1e-7.
+        mm_truncate (int | None, optional):
+            uses top k eigenvalues to compute the update. Defaults to None.
+        mm_damping (float, optional):
+            added to eigenvalues when computing the update. Defaults to 1e-4.
+        mm_rdamping (float, optional):
+            added to eigenvalues when computing the update, relative to largest eigenvalue. Defaults to 0.
+        id_reg (float, optional):
+            multiplier to identity matrix added to preconditioner before computing update
+            If this value is given, solution from Nyström sketch-and-solve will be used to compute the update.
+            This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
+        concat_params (bool, optional):
+            whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
+        update_freq (int, optional): update frequency. Defaults to 1.
+        inner (Chainable | None, optional): inner modules. Defaults to None.
+    """
+    def __init__(
+        self,
+        rank:int = 100,
+        beta=0.95,
+        oversampling: int = 10,
+        eig_tol: float | None = 1e-32,
+        damping: float = 0,
+        rdamping: float = 0,
+        mm_tol: float = 0,
+        mm_truncate: int | None = None,
+        mm_damping: float = 0,
+        mm_rdamping: float = 0,
+        id_reg: float | None = None,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
+        eigenbasis_optimizer: LREOptimizerBase | None = None,
+        orthogonalize_interval: int | None = 100,
+
+        concat_params: bool = True,
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        for k in ["self", "concat_params", "inner", "update_freq"]:
+            del defaults[k]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
+
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        state["step"] = state.get("step", 0) + 1
+        rank = setting["rank"]
+        device = tensor.device
+        dtype = tensor.dtype
+        beta = setting["beta"]
+
+        try:
+            if "L" not in state:
+                # use just tensor and zero L and Q with zero weight
+
+                L, Q = adanystrom_update(
+                    L1=torch.zeros(rank, device=device, dtype=dtype),
+                    Q1=torch.zeros((tensor.numel(), rank), device=device, dtype=dtype),
+                    v2=tensor.ravel(),
+                    w1=0,
+                    w2=1-beta,
+                    rank=rank,
+                    oversampling_p=setting["oversampling"],
+                    eig_tol=setting["eig_tol"],
+                    damping=setting["damping"],
+                    rdamping=setting["rdamping"],
+                    orthogonalize_method=setting["orthogonalize_method"],
+                )
+
+                state["L"] = state["L_reg"] = L
+                state["Q"] = state["Q_reg"] = Q
+
+            else:
+                L = state["L"]
+                Q = state["Q"]
+
+                w1 = beta
+                w2 = 1 - w1
+
+                # compute new factors (this function truncates them)
+                L_new, Q_new = adanystrom_update(
+                    L1=L,
+                    Q1=Q,
+                    v2=tensor.ravel(),
+                    w1=w1,
+                    w2=w2,
+                    rank=rank,
+                    oversampling_p=setting["oversampling"],
+                    eig_tol=setting["eig_tol"],
+                    damping=setting["damping"],
+                    rdamping=setting["rdamping"],
+                    orthogonalize_method=setting["orthogonalize_method"],
+                )
+
+                _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
+
+        except torch.linalg.LinAlgError:
+            pass
+
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        if "L_reg" not in state:
+            return tensor.clip(-0.1, 0.1)
+
+        if "eigenbasis_state" not in state:
+            state["eigenbasis_state"] = {}
+
+        return eigengrad_apply(
+            tensor=tensor,
+            L_reg = state["L_reg"],
+            Q_reg = state["Q_reg"],
+            beta = setting["beta"],
+            step = state["step"],
+            debias = True,
+            id_reg = setting["id_reg"],
+            eigenbasis_optimizer = setting["eigenbasis_optimizer"],
+            eigenbasis_state = state["eigenbasis_state"]
+        )
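
`adanystrom_update` maintains a rank-limited eigendecomposition of an exponential moving average of gradient outer products: it multiplies `w1 * Q1 diag(L1) Q1^T + w2 * g g^T` against a Gaussian test matrix, orthogonalizes the sketch, eigendecomposes the small projected matrix, and keeps the top `rank` eigenpairs. The new `AdaNystrom` module ships without a usage example of its own; by analogy with the other modules in this release it would be chained as in the sketch below. Whether the class is re-exported as `tz.m.AdaNystrom`, and whether `tz.Optimizer` follows the standard `torch.optim` step/zero_grad interface, are assumptions here, so treat this as a sketch rather than documented API:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# Adam-like step preconditioned by a rank-100 Nystrom approximation
# of the EMA of gradient outer products.
opt = tz.Optimizer(
    model.parameters(),
    tz.m.AdaNystrom(rank=100, beta=0.95),
    tz.m.LR(1e-3),
)

for _ in range(10):
    x, y = torch.randn(16, 10), torch.randn(16, 1)
    loss = (model(x) - y).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```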