torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/common_directions_whiten.py
@@ -0,0 +1,142 @@
+ from collections import deque
+ from typing import Literal
+
+ import torch
+
+ from torchzero.core import Chainable, TensorTransform
+ from torchzero.linalg import matrix_power_eigh, torch_linalg, orthogonalize, OrthogonalizeMethod, regularize_eigh
+ from torchzero.utils import TensorList, vec_to_tensors_
+
+
+ def update_subspace_preconditioner_(
+     grad: torch.Tensor, # grads and basis are stored as vectors for matmul
+     basis: torch.Tensor, # (ndim, k)
+     accumulator_: torch.Tensor, # (k, k)
+     beta: float | None,
+ ):
+     projected = basis.T @ grad # (k,)
+     outer = torch.outer(projected, projected)
+
+     if beta is None: accumulator_.add_(outer)
+     else: accumulator_.lerp_(outer, 1-beta)
+
+ # note: subspace optimizers could also be run in this basis
+ def apply_subspace_preconditioner(
+     tensor: torch.Tensor,
+     basis: torch.Tensor, # (ndim, k)
+     accumulator: torch.Tensor,
+     tol: float,
+     truncate: int | None,
+     damping: float,
+     rdamping: float,
+ ):
+     L, Q = torch_linalg.eigh(accumulator, retry_float64=True)
+     L, Q = regularize_eigh(L=L, Q=Q, truncate=truncate, tol=tol, damping=damping, rdamping=rdamping)
+
+     if L is None or Q is None:
+         return tensor.clip(-0.1, 0.1)
+
+     preconditioner = (Q * L.rsqrt().unsqueeze(-2)) @ Q.mH
+
+     tensor_projected = basis.T @ tensor # (k,)
+     update_projected = preconditioner @ tensor_projected # (k,)
+     return basis @ update_projected # (ndim,)
+
+
+ class CommonDirectionsWhiten(TensorTransform):
+     """Whitens the update in the subspace spanned by a history of gradients or gradient differences.
+
+     Args:
+         beta: EMA coefficient for the preconditioner accumulator in the subspace basis; None accumulates a plain sum.
+         basis_beta: EMA coefficient for the basis itself, i.e. how quickly the basis is allowed to change.
+     """
+
+     def __init__(
+         self,
+         k: int = 100,
+         beta: float | None = 0.95,
+         basis_beta=0.95,
+         tol: float = 1e-7,
+         truncate: int | None = None,
+         damping: float = 1e-4,
+         rdamping: float = 0,
+         basis_type: Literal["gradients", "differences"] = "differences",
+         orthogonalize_method: OrthogonalizeMethod | None = 'newtonschulz',
+
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         for key in ["self", "inner", "concat_params"]:
+             del defaults[key]
+
+         super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+     @torch.no_grad
+     def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+         g = tensor.ravel()
+         k = setting['k']
+         beta = setting['beta']
+         basis_beta = setting['basis_beta']
+         step = state.get("step", 0)
+         state["step"] = step + 1
+
+         # initialize history
+         if 'history' not in state:
+             state['history'] = deque(maxlen=k)
+             state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
+             state['basis'] = torch.zeros(g.numel(), k, device=g.device, dtype=g.dtype)
+
+         history: deque = state['history']
+         accumulator = state['accumulator']
+         basis = state['basis']
+         history.append(g)
+
+         # stack history into the new basis term; if history isn't full, fill with random vectors
+         if len(history) < k:
+             basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
+             history_basis = torch.stack(tuple(history), -1)
+             basis_t[:, -len(history):] = history_basis
+
+         else:
+             basis_t = torch.stack(tuple(history), -1)
+
+         # in this case the basis uses differences of gradients, except the last entry which is the gradient
+         if setting["basis_type"] == "differences":
+             basis_t[:, :-1] = basis_t[:, :-1] - basis_t[:, 1:]
+
+         # normalize or orthonormalize the new basis term
+         if setting["orthogonalize_method"] is not None:
+             basis_t = orthogonalize(basis_t, method=setting["orthogonalize_method"])
+         else:
+             basis_t = (basis_t - basis_t.mean()) / basis_t.std().clip(min=torch.finfo(g.dtype).tiny * 2)
+
+         # lerp basis
+         basis.lerp_(basis_t, 1-basis_beta)
+         basis = basis / (1 - basis_beta ** (step+1)) # correct bias on basis EMA
+         update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+     @torch.no_grad
+     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+         g = tensor.ravel()
+
+         basis = state['basis']
+         accumulator = state['accumulator']
+         step = state["step"]
+         if setting["beta"] is not None: accumulator = accumulator / (1 - setting["beta"] ** (step+1)) # correct bias on accumulator EMA
+
+         try:
+             preconditioned = apply_subspace_preconditioner(
+                 g,
+                 basis,
+                 accumulator,
+                 tol=setting["tol"],
+                 truncate=setting["truncate"],
+                 damping=setting["damping"],
+                 rdamping=setting["rdamping"],
+             )
+         except torch.linalg.LinAlgError:
+             preconditioned = g.clip(-0.1, 0.1)
+
+         return preconditioned.view_as(tensor)
+
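The module above whitens updates inside a small subspace: recent gradients (or their differences) form a basis, the second-moment matrix of the projected gradients is accumulated, and its inverse square root is applied in that basis. Below is a minimal self-contained sketch of that idea, using plain torch.linalg.eigh instead of torchzero's regularized helpers; it is illustrative only and not part of the package.

# Illustrative sketch of the subspace-whitening idea behind CommonDirectionsWhiten.
import torch

torch.manual_seed(0)
d, k = 1000, 16
basis, _ = torch.linalg.qr(torch.randn(d, k))   # orthonormal basis of the subspace, (d, k)
accumulator = torch.eye(k)                      # EMA of projected gradient outer products, (k, k)
beta = 0.95

for _ in range(100):
    g = torch.randn(d)
    p = basis.T @ g                             # project the gradient into the subspace, (k,)
    accumulator.lerp_(torch.outer(p, p), 1 - beta)

# whiten a new gradient inside the subspace: apply C^{-1/2} to its projection
L, Q = torch.linalg.eigh(accumulator)
L = L.clamp(min=1e-12)
g = torch.randn(d)
whitened = basis @ (Q @ ((Q.T @ (basis.T @ g)) / L.sqrt()))
print(whitened.shape)  # torch.Size([1000])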
torchzero/modules/experimental/cubic_adam.py
@@ -0,0 +1,160 @@
+ from typing import Any, Literal
+
+ import torch
+
+ from ...core import TensorTransform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ..adaptive.lre_optimizers import LREOptimizerBase, _squared_reproject
+
+
+ def signed_cbrt(x: TensorList | Any) -> Any:
+     return x.sign() * x.abs().pow(1/3)
+
+ def _clip_min_magnitude(x: torch.Tensor, eps: float):
+     return x.sign() * x.abs().clamp(min=eps)
+
+ _cubic_adam_mode = Literal["signed_cbrt", "unsigned_cbrt", "halve"]
+
+ def _cubic_minimize(A: torch.Tensor | Any, B: torch.Tensor | Any, C: torch.Tensor | Any, eps):
+     """minimizes (A/3)x^3 + (B/2)x^2 + Cx"""
+     discriminant = B**2 - 4 * A * C
+
+     denom = _clip_min_magnitude(2 * A, eps)
+     root = discriminant.clamp(min=0).sqrt_()
+
+     x0 = (-B + root) / denom
+     x1 = (-B - root) / denom
+
+     f0 = (A/3)*x0**3 + (B/2)*x0**2 + C*x0
+     f1 = (A/3)*x1**3 + (B/2)*x1**2 + C*x1
+
+     x_star = x0.where(f0 < f1, x1)
+
+     adam = -C / (B + eps)
+     return adam.where(discriminant < 0, x_star)
+
+ def cubic_adam_(
+     tensors: TensorList,
+     exp_avg_: TensorList,
+     exp_avg_sq_: TensorList,
+     exp_avg_cu_: TensorList,
+     alpha: float | NumberList,
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     beta3: float | NumberList,
+     eps: float | NumberList,
+     debiased: bool,
+     step: int,
+
+     mode: _cubic_adam_mode = 'signed_cbrt'
+ ):
+     exp_avg_.lerp_(tensors, 1-beta1)
+     exp_avg_sq_.lerp_(tensors**2, 1-beta2)
+     exp_avg_cu_.lerp_(tensors**3, 1-beta3)
+
+     if debiased:
+         m1 = exp_avg_ / (1 - beta1 ** step)
+         m2 = exp_avg_sq_ / (1 - beta2 ** step)
+         m3 = exp_avg_cu_ / (1 - beta3 ** step)
+     else:
+         m1, m2, m3 = exp_avg_, exp_avg_sq_, exp_avg_cu_
+
+     # Adam's update minimizes the quadratic model (B/2)x^2 + Cx;
+     # here we instead minimize the cubic model (A/3)x^3 + (B/2)x^2 + Cx
+
+     if mode == "signed_cbrt": A = signed_cbrt(m3)
+     elif mode == "unsigned_cbrt": A = m3.abs().pow(1/3)
+     elif mode == 'halve': A = 0.5 * m3
+     else: raise ValueError(mode)
+
+     B = m2.sqrt()
+     C = m1
+     x_star = _cubic_minimize(A, B, C, eps)
+     return x_star.mul_(-alpha)
+
+ class CubicAdam(TensorTransform):
+     """Adam variant that also tracks a third moment and minimizes a cubic model instead of Adam's quadratic one."""
+     def __init__(
+         self,
+         beta1: float = 0.9,
+         beta2: float = 0.99,
+         beta3: float = 0.99,
+         eps: float = 1e-8,
+         debiased: bool = True,
+         alpha: float = 1.,
+
+         mode: _cubic_adam_mode = 'signed_cbrt'
+     ):
+         defaults = dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, debiased=debiased, alpha=alpha, mode=mode)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         beta1, beta2, beta3, eps, alpha = unpack_dicts(settings, 'beta1', 'beta2', 'beta3', 'eps', 'alpha', cls=NumberList)
+         exp_avg, exp_avg_sq, exp_avg_cu = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
+
+         return cubic_adam_(
+             tensors=TensorList(tensors),
+             exp_avg_=exp_avg,
+             exp_avg_sq_=exp_avg_sq,
+             exp_avg_cu_=exp_avg_cu,
+             alpha=alpha,
+             beta1=beta1,
+             beta2=beta2,
+             beta3=beta3,
+             eps=eps,
+             debiased=settings[0]['debiased'],
+             step=step,
+
+             mode=settings[0]["mode"]
+         )
+
+ class SubspaceCubicAdam(LREOptimizerBase):
+     """Runs cubic Adam in a low-rank eigenbasis."""
+     def __init__(self, beta1=0.9, beta2=0.95, beta3=0.95, eps=1e-8, mode: _cubic_adam_mode = 'signed_cbrt', cautious: bool = False, exact_reproject: bool = True):
+         self.beta1 = beta1
+         self.beta2 = beta2
+         self.beta3 = beta3
+         self.eps = eps
+         self.cautious = cautious
+         self.mode: _cubic_adam_mode = mode
+         self.exact_reproject = exact_reproject
+
+     def step(self, g, L, Q, state):
+         g = Q.T @ g
+
+         if "exp_avg" not in state:
+             state["exp_avg"] = torch.zeros_like(g)
+             state["exp_avg_sq"] = torch.zeros_like(g)
+             state["exp_avg_cu"] = torch.zeros_like(g)
+             state["current_step"] = 1
+
+         dir = cubic_adam_(
+             tensors = TensorList([g]),
+             exp_avg_ = TensorList([state["exp_avg"]]),
+             exp_avg_sq_ = TensorList([state["exp_avg_sq"]]),
+             exp_avg_cu_ = TensorList([state["exp_avg_cu"]]),
+             alpha = 1,
+             beta1 = self.beta1,
+             beta2 = self.beta2,
+             beta3 = self.beta3,
+             eps = self.eps,
+             debiased = True,
+             step = state["current_step"],
+
+             mode = self.mode,
+         )[0]
+
+         state["current_step"] += 1
+         return Q @ dir
+
+     def reproject(self, L_old, Q_old, L_new, Q_new, state):
+         if "exp_avg" not in state: return
+
+         C = Q_new.T @ Q_old
+
+         state["exp_avg"] = C @ state["exp_avg"]
+         state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], exact=self.exact_reproject)
+         state["exp_avg_cu"] = C.pow(3) @ state["exp_avg_cu"] # exact reproject with 1_000_000 is feasible
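For reference, the cubic step above follows from a small model argument: plain Adam's update minimizes the quadratic model (B/2)x^2 + Cx with B = sqrt(v_hat) and C = m_hat, giving x = -C/B, while CubicAdam adds an (A/3)x^3 term built from the third moment. The stationary points then solve Ax^2 + Bx + C = 0, and _cubic_minimize picks the root with the lower model value, falling back to the Adam step when the discriminant B^2 - 4AC is negative. A scalar sanity-check sketch of that selection rule (illustrative only, not part of the package):

# Check the root selection of _cubic_minimize for scalar A, B, C with a positive discriminant.
import torch

A, B, C = torch.tensor([0.7]), torch.tensor([1.3]), torch.tensor([-0.4])
disc = B**2 - 4 * A * C
x0 = (-B + disc.sqrt()) / (2 * A)
x1 = (-B - disc.sqrt()) / (2 * A)
f = lambda x: (A / 3) * x**3 + (B / 2) * x**2 + C * x
x_star = torch.where(f(x0) < f(x1), x0, x1)   # the stationary point that is the local minimum

# Adam's quadratic model drops the cubic term entirely; _cubic_minimize returns
# this value whenever disc < 0 and no stationary point exists.
x_adam = -C / B
print(x_star.item(), x_adam.item())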
torchzero/modules/experimental/eigen_sr1.py
@@ -0,0 +1,182 @@
+ import torch
+
+ from ...core import Transform
+ from ...linalg.orthogonalize import orthogonalize, OrthogonalizeMethod
+ from ...linalg.eigh import eigh_plus_uuT, regularize_eigh
+ from ...utils import TensorList, unpack_states, vec_to_tensors_
+ from ..opt_utils import safe_clip
+ from .eigengrad import _eigengrad_update_state_, eigengrad_apply
+
+
+ def sr1_u(L: torch.Tensor, Q: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+     """u from the u u^T SR1 correction, and its sign"""
+     r = y - torch.linalg.multi_dot([Q, L.diag_embed(), Q.T, s]) # pylint:disable=not-callable
+     rs = r.dot(s)
+
+     if rs.abs() < tol * torch.linalg.vector_norm(r) * torch.linalg.vector_norm(s): # pylint:disable=not-callable
+         return None, None
+
+     u = r / rs.abs().sqrt()
+     return u, torch.sign(rs)
+
+ class EigenSR1(Transform):
+     """Maintains an SR1 hessian approximation as a low-rank eigendecomposition and preconditions updates with it."""
+     def __init__(
+         self,
+         rank: int = 100,
+         tol: float = 1e-32,
+         eig_tol: float | None = None,
+         damping: float = 0,
+         rdamping: float = 0,
+         abs: bool = False,
+         mm_tol: float = 1e-7,
+         mm_truncate: int | None = None,
+         mm_damping: float = 1e-4,
+         mm_rdamping: float = 0,
+         mm_abs: bool = True,
+         id_reg: float | None = None,
+         column_space_tol=1e-9,
+         beta: float = 0.95,
+         balance_tol: float = 10,
+         balance_strength: float = 1e-1,
+
+         eigenbasis_optimizer = None,
+         update_freq: int = 1,
+         init_steps: int = 10,
+         orthogonalize_interval: int | None = 1,
+         orthogonalize_method: OrthogonalizeMethod = 'qr',
+
+         hvp_method = "autograd",
+         h = 1e-3,
+         inner = None,
+
+     ):
+         defaults = locals().copy()
+         for k in ["self", "inner"]:
+             del defaults[k]
+
+         super().__init__(defaults)
+
+     def update_states(self, objective, states, settings):
+         fs = settings[0]
+         step = self.increment_counter("step", 0)
+
+         if step % fs["update_freq"] == 0:
+
+             params = TensorList(objective.params)
+
+             # compute s (random vectors during the init steps, otherwise the parameter difference)
+             if ("p_prev" not in self.global_state) or (step < fs["init_steps"]):
+                 s_list = params.sample_like('rademacher')
+
+             else:
+                 p_prev = self.global_state["p_prev"]
+                 s_list = params - p_prev
+
+                 if s_list.dot(s_list) < torch.finfo(s_list[0].dtype).tiny * 2:
+                     s_list = params.sample_like('rademacher')
+
+             self.global_state["p_prev"] = params
+
+             # compute y as hessian-vector product with s
+             Hz_list, _ = objective.hessian_vector_product(s_list, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])
+
+             s = torch.cat([t.ravel() for t in s_list])
+             y = torch.cat([t.ravel() for t in Hz_list])
+
+             # keep an exponential moving average of the hessian diagonal and balance the eigenvalues against it
+             if (fs["balance_strength"] != 0) and (step > fs["init_steps"]) and ("L" in self.global_state):
+
+                 D = s * y # hutchinson estimator
+                 exp_avg = self.global_state.get("exp_avg", None)
+
+                 if exp_avg is None:
+                     exp_avg = self.global_state["exp_avg"] = D
+
+                 else:
+                     exp_avg.lerp_(D, weight=1-fs["beta"])
+
+                 L = self.global_state["L"]
+                 L_abs = L.abs()
+                 tau = L_abs.amax() / exp_avg.abs().amax()
+
+                 if tau > fs["balance_tol"]:
+                     L_balanced = L_abs.pow((1 / tau) ** (1 / fs["balance_strength"])).copysign(L)
+                     self.global_state["L"] = torch.where(L_abs > 1, L_balanced, L)
+
+             # initialize L and Q on the 1st step
+             if "L" not in self.global_state:
+
+                 L = torch.zeros(1, dtype=s.dtype, device=s.device) # (rank,)
+                 Q = torch.zeros([s.numel(), 1], dtype=s.dtype, device=s.device) # (ndim, rank)
+
+                 u, sign = sr1_u(L=L, Q=Q, s=s, y=y, tol=0)
+                 assert u is not None and sign is not None
+
+                 # for u u^T, u is the eigenvector and u^T u is the eigenvalue
+                 norm = torch.linalg.vector_norm(u).clip(min=torch.finfo(u.dtype).tiny * 2) # pylint:disable=not-callable
+
+                 self.global_state["L"] = self.global_state["L_reg"] = (u.dot(u).unsqueeze(0) / norm) * sign # (rank,)
+                 self.global_state["Q"] = self.global_state["Q_reg"] = u.unsqueeze(-1) / norm # (ndim, rank)
+
+             # update the hessian approximation
+             else:
+                 try:
+                     L = self.global_state["L"]
+                     Q = self.global_state["Q"]
+
+                     H_step = self.increment_counter("H_step", start=0)
+                     if (fs["orthogonalize_interval"] is not None) and (H_step % fs["orthogonalize_interval"] == 0):
+                         Q = orthogonalize(Q, method=fs["orthogonalize_method"])
+
+                     u, sign = sr1_u(L=L, Q=Q, s=s, y=y, tol=fs["tol"])
+
+                     if (u is not None) and (sign is not None):
+
+                         # compute new factors
+                         L_new, Q_new = eigh_plus_uuT(L, Q, u, tol=fs["column_space_tol"], alpha=sign.item(), retry_float64=True)
+
+                         # truncate/regularize new factors (those go into the accumulator)
+                         L_new, Q_new = regularize_eigh(L=L_new, Q=Q_new, truncate=min(fs["rank"], s.numel()),
+                                                        tol=fs["eig_tol"], damping=fs["damping"], rdamping=fs["rdamping"])
+
+                         _eigengrad_update_state_(state=self.global_state, setting=fs, L_new=L_new, Q_new=Q_new)
+
+                 except torch.linalg.LinAlgError:
+                     pass
+
+
+
+     def apply_states(self, objective, states, settings):
+         fs = settings[0]
+         updates = objective.get_updates()
+
+         if "eigenbasis_state" not in self.global_state:
+             self.global_state["eigenbasis_state"] = {}
+
+         step = self.global_state["step"] # starts at 0
+         if step < fs["init_steps"]:
+
+             # skip the update for the first init_steps to let the hessian approximation kick-start
+             objective.stop = True
+             objective.skip_update = True
+             return objective
+
+         if "L_reg" not in self.global_state:
+             TensorList(updates).clip_(-0.1, 0.1)
+             return objective
+
+         dir = eigengrad_apply(
+             tensor = torch.cat([t.ravel() for t in updates]),
+             L_reg = self.global_state["L_reg"],
+             Q_reg = self.global_state["Q_reg"],
+             beta = None,
+             step = None,
+             debias = False,
+             id_reg = fs["id_reg"],
+             eigenbasis_optimizer = fs["eigenbasis_optimizer"],
+             eigenbasis_state = self.global_state["eigenbasis_state"],
+             whiten_fn = lambda x: x
+         )
+
+         vec_to_tensors_(dir, updates)
+         return objective
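EigenSR1 maintains the SR1 quasi-Newton approximation B ≈ Q diag(L) Q^T and folds each secant pair (s, y) in as a signed rank-one correction: with r = y - Bs, the SR1 update is B+ = B + r r^T / (r^T s), which sr1_u returns in factored form as sign(r^T s) * u u^T with u = r / sqrt(|r^T s|). A small dense check of that factorization and of the secant condition B+ s = y (illustrative only, not part of the package):

# Verify the factored SR1 correction produced by sr1_u on small dense matrices.
import torch

torch.manual_seed(0)
n = 6
B = torch.randn(n, n); B = (B + B.T) / 2          # current symmetric approximation
H = torch.randn(n, n); H = (H + H.T) / 2          # "true" hessian acting on s
s = torch.randn(n)
y = H @ s

r = y - B @ s
rs = r.dot(s)
u = r / rs.abs().sqrt()
B_sr1 = B + torch.sign(rs) * torch.outer(u, u)    # same as B + r r^T / (r^T s)

print(torch.allclose(B_sr1, B + torch.outer(r, r) / rs, atol=1e-5))  # factorization matches
print(torch.allclose(B_sr1 @ s, y, atol=1e-4))                       # SR1 secant condition holds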
torchzero/modules/experimental/eigengrad.py
@@ -0,0 +1,207 @@
+ # pylint: disable = non-ascii-name
+ from collections.abc import Mapping
+
+ import torch
+
+ from ...core import Chainable, TensorTransform
+ from ...linalg.eigh import eigh_plus_uuT, regularize_eigh
+ from ...linalg.orthogonalize import OrthogonalizeMethod, orthogonalize
+ from ...linalg.linear_operator import Eigendecomposition
+ from ..adaptive.lre_optimizers import LREOptimizerBase
+
+
+ def _eigengrad_update_state_(state: dict, setting: Mapping, L_new: torch.Tensor | None, Q_new: torch.Tensor | None):
+     """stores L, Q, L_reg, Q_reg and reprojects the eigenbasis optimizer (also used by other eigen-based modules)"""
+     if (L_new is not None) and (Q_new is not None):
+
+         # re-orthogonalize
+         orthogonalize_interval = setting["orthogonalize_interval"]
+         if orthogonalize_interval is not None:
+             Q_step = state.get("Q_step", 0)
+             state["Q_step"] = Q_step + 1
+             if Q_step % orthogonalize_interval == 0:
+                 Q_new = orthogonalize(Q_new, method=setting["orthogonalize_method"])
+
+         # take absolute value (for hessian)
+         if setting.get("abs", False):
+             L_new = L_new.abs()
+
+         # store
+         state["L"] = L_new
+         state["Q"] = Q_new
+
+         # absolute value for matmul
+         if setting.get("mm_abs", False):
+             L_new = L_new.abs()
+
+         # regularize for matmul
+         # this second round of regularization is only used for preconditioning
+         # and doesn't affect the accumulator
+         L_reg_new, Q_reg_new = regularize_eigh(L=L_new, Q=Q_new,
+                                                truncate=setting["mm_truncate"],
+                                                tol=setting["mm_tol"],
+                                                damping=setting["mm_damping"],
+                                                rdamping=setting["mm_rdamping"],
+                                                )
+
+         # reproject the eigenbasis optimizer
+         if (L_reg_new is not None) and (Q_reg_new is not None):
+             eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
+             if eigenbasis_optimizer is not None:
+                 eigenbasis_optimizer.reproject(L_old=state["L_reg"], Q_old=state["Q_reg"], L_new=L_reg_new,
+                                                Q_new=Q_reg_new, state=state["eigenbasis_state"])
+
+             state["L_reg"] = L_reg_new
+             state["Q_reg"] = Q_reg_new
+
+
+ def eigengrad_apply(
+     tensor: torch.Tensor,
+     L_reg: torch.Tensor,
+     Q_reg: torch.Tensor,
+     beta: float | None,
+     step: int | None,
+     debias: bool,
+     id_reg: float | None,
+     eigenbasis_optimizer: LREOptimizerBase | None,
+     eigenbasis_state: dict,
+
+     whiten_fn = torch.sqrt
+ ):
+     # debias
+     if debias:
+         assert beta is not None and step is not None
+         L_reg = L_reg / (1 - beta ** step)
+
+     # step with the eigenbasis optimizer
+     if eigenbasis_optimizer is not None:
+         if (id_reg is not None) and (id_reg != 0):
+             raise RuntimeError("id_reg is not compatible with eigenbasis_optimizer")
+
+         update = eigenbasis_optimizer.step(tensor.ravel(), L=L_reg, Q=Q_reg, state=eigenbasis_state)
+         return update.view_as(tensor)
+
+     # or just whiten
+     if id_reg is None or id_reg == 0:
+         G = Eigendecomposition(whiten_fn(L_reg), Q_reg, use_nystrom=False)
+         dir = G.solve(tensor.ravel())
+
+     else:
+         G = Eigendecomposition(whiten_fn(L_reg), Q_reg, use_nystrom=True)
+         dir = G.solve_plus_diag(tensor.ravel(), diag=id_reg)
+
+     return dir.view_as(tensor)
+
+ class Eigengrad(TensorTransform):
+     """Maintains the gradient covariance matrix as a low-rank eigendecomposition, updated with cheap rank-1
+     symmetric corrections, and whitens the update with it.
+
+     Args:
+         rank (int): maximum allowed rank.
+         beta (float, optional): beta for the covariance matrix exponential moving average. Defaults to 0.95.
+         eig_tol (float, optional):
+             removes eigenvalues this much smaller than the largest eigenvalue when updating the preconditioner. Defaults to 1e-5.
+         damping (float, optional):
+             added to eigenvalues when updating the preconditioner. Defaults to 0.
+         rdamping (float, optional):
+             added to eigenvalues when updating the preconditioner, relative to the largest eigenvalue. Defaults to 0.
+         mm_tol (float, optional):
+             removes eigenvalues this much smaller than the largest eigenvalue when computing the update. Defaults to 0.
+         mm_truncate (int | None, optional):
+             uses the top k eigenvalues to compute the update. Defaults to None.
+         mm_damping (float, optional):
+             added to eigenvalues when computing the update. Defaults to 1e-4.
+         mm_rdamping (float, optional):
+             added to eigenvalues when computing the update, relative to the largest eigenvalue. Defaults to 0.
+         id_reg (float, optional):
+             multiplier of the identity matrix added to the preconditioner before computing the update.
+             If this value is given, the solution from Nyström sketch-and-solve will be used to compute the update.
+             This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
+         column_space_tol (float, optional):
+             tolerance for deciding whether a new eigenvector is within the column space of the covariance matrix. Defaults to 1e-9.
+         concat_params (bool, optional):
+             whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
+         update_freq (int, optional): update frequency. Defaults to 1.
+         inner (Chainable | None, optional): inner modules. Defaults to None.
+
+     """
+
+     def __init__(
+         self,
+         rank: int = 100,
+         beta=0.95,
+         eig_tol: float | None = 1e-5,
+         damping: float = 0,
+         rdamping: float = 0,
+         mm_tol: float = 0,
+         mm_truncate: int | None = None,
+         mm_damping: float = 1e-4,
+         mm_rdamping: float = 0,
+         id_reg: float | None = None,
+         column_space_tol = 1e-9,
+
+         orthogonalize_interval: int | None = None,
+         orthogonalize_method: OrthogonalizeMethod = 'qr',
+
+         eigenbasis_optimizer: LREOptimizerBase | None = None,
+         concat_params: bool = True,
+         update_freq: int = 1,
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         for k in ["self", "concat_params", "inner", "update_freq"]:
+             del defaults[k]
+
+         super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
+
+     def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+         state["step"] = state.get("step", 0) + 1
+         beta = setting["beta"]
+
+         if "L" not in state:
+             # for u u^T, u is the eigenvector and u^T u is the eigenvalue
+             norm = torch.linalg.vector_norm(tensor).clip(min=torch.finfo(tensor.dtype).tiny * 2) # pylint:disable=not-callable
+
+             state["L"] = state["L_reg"] = (tensor.dot(tensor).unsqueeze(0) / norm) # (rank,)
+             state["Q"] = state["Q_reg"] = tensor.unsqueeze(-1) / norm # (ndim, rank)
+
+         else:
+             try:
+                 L = state["L"]
+                 Q = state["Q"]
+
+                 # compute new factors
+                 L_new, Q_new = eigh_plus_uuT(L*beta, Q, tensor, alpha=(1-beta), tol=setting["column_space_tol"], retry_float64=True)
+
+                 # truncate/regularize new factors (those go into the accumulator)
+                 L_new, Q_new = regularize_eigh(L=L_new, Q=Q_new, truncate=setting["rank"], tol=setting["eig_tol"],
+                                                damping=setting["damping"], rdamping=setting["rdamping"])
+
+                 _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
+
+             except torch.linalg.LinAlgError:
+                 pass
+
+     def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+         if "L_reg" not in state:
+             return tensor.clip(-0.1, 0.1)
+
+         if "eigenbasis_state" not in state:
+             state["eigenbasis_state"] = {}
+
+         return eigengrad_apply(
+             tensor = tensor,
+             L_reg = state["L_reg"],
+             Q_reg = state["Q_reg"],
+             beta = setting["beta"],
+             step = state["step"],
+             debias = True,
+             id_reg = setting["id_reg"],
+             eigenbasis_optimizer = setting["eigenbasis_optimizer"],
+             eigenbasis_state = state["eigenbasis_state"]
+         )
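Eigengrad keeps the same quantity a full-matrix method would build, an exponential moving average of g g^T, but stores only a rank-limited eigendecomposition of it (updated via eigh_plus_uuT and regularize_eigh) and uses it to whiten the update. A dense reference sketch of the computation it approximates, in plain torch and illustrative only, not part of the package:

# Dense version of the Eigengrad idea: EMA of the gradient covariance, whitening
# with its top-k eigenpairs. Eigengrad avoids ever forming the d x d matrix.
import torch

torch.manual_seed(0)
d, k, beta = 200, 10, 0.95
C = torch.zeros(d, d)

for _ in range(50):
    g = torch.randn(d)
    C.lerp_(torch.outer(g, g), 1 - beta)          # EMA of the gradient covariance

L, Q = torch.linalg.eigh(C)
L, Q = L[-k:], Q[:, -k:]                          # keep the k largest eigenpairs
L = L.clamp(min=1e-12)

g = torch.randn(d)
whitened = Q @ ((Q.T @ g) / L.sqrt())             # ~ C^{-1/2} g within the top-k eigenspace
print(whitened.shape)  # torch.Size([200])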