torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -169,7 +169,7 @@ class FullMatrixAdagrad(TensorTransform):
     """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

     Note:
-        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.LMAdagrad``.
+        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.GGT``.

     Args:
         reg (float, optional): regularization, scale of identity matrix added to accumulator. Defaults to 1e-12.
@@ -190,7 +190,7 @@ class FullMatrixAdagrad(TensorTransform):

     Plain full-matrix adagrad
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.FullMatrixAdagrd(),
         tz.m.LR(1e-2),
@@ -199,7 +199,7 @@ class FullMatrixAdagrad(TensorTransform):

     Full-matrix RMSprop
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.FullMatrixAdagrad(beta=0.99),
         tz.m.LR(1e-2),
@@ -208,7 +208,7 @@ class FullMatrixAdagrad(TensorTransform):

     Full-matrix Adam
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
         tz.m.Debias(0.9, 0.999),
@@ -240,22 +240,22 @@ class FullMatrixAdagrad(TensorTransform):
     def single_tensor_update(self, tensor, param, grad, loss, state, setting):

         G = tensor.ravel()
-        GGᵀ = torch.outer(G, G)
+        GGT = torch.outer(G, G)

         # initialize
         if "accumulator" not in state:
             init = setting['init']
-            if init == 'identity': state['accumulator'] = torch.eye(GGᵀ.size(0), device=GGᵀ.device, dtype=GGᵀ.dtype)
-            elif init == 'zeros': state['accumulator'] = torch.zeros_like(GGᵀ)
-            elif init == 'GGT': state['accumulator'] = GGᵀ.clone()
+            if init == 'identity': state['accumulator'] = torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype)
+            elif init == 'zeros': state['accumulator'] = torch.zeros_like(GGT)
+            elif init == 'GGT': state['accumulator'] = GGT.clone()
             else: raise ValueError(init)

         # update
         beta = setting['beta']
         accumulator: torch.Tensor = state["accumulator"]

-        if beta is None: accumulator.add_(GGᵀ)
-        else: accumulator.lerp_(GGᵀ, 1-beta)
+        if beta is None: accumulator.add_(GGT)
+        else: accumulator.lerp_(GGT, 1-beta)

         # update number of GGᵀ in accumulator for divide
         state['num_GGTs'] = state.get('num_GGTs', 0) + 1
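Editor's note: the hunk above only renames the superscript variable and keeps the accumulator logic. For orientation, a minimal standalone sketch of the full-matrix Adagrad rule such an accumulator drives is shown below. The helper name `full_matrix_adagrad_step` and the plain eigh-based inverse square root are illustrative assumptions, not the module's actual apply step (which is not shown in this hunk).

```python
import torch

def full_matrix_adagrad_step(g: torch.Tensor, accumulator: torch.Tensor, reg: float = 1e-12):
    """g is a flattened gradient; accumulator holds the running sum of outer products G Gᵀ."""
    GGT = torch.outer(g, g)
    accumulator += GGT                               # same accumulation as accumulator.add_(GGT) in the hunk
    A = accumulator + reg * torch.eye(g.numel())     # regularize before inverting
    L, Q = torch.linalg.eigh(A)                      # A^(-1/2) via eigendecomposition
    preconditioner = (Q * L.clamp(min=reg).rsqrt()) @ Q.T
    return preconditioner @ g                        # update direction A^(-1/2) g

g = torch.randn(10)
acc = torch.zeros(10, 10)
update = full_matrix_adagrad_step(g, acc)
```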
@@ -86,7 +86,7 @@ class AdaHessian(Transform):
     Using AdaHessian:

     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.AdaHessian(),
         tz.m.LR(0.1)
@@ -97,7 +97,7 @@ class AdaHessian(Transform):
     Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
     AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
         tz.m.LR(0.1)
@@ -2,7 +2,7 @@ import torch

 from ...core import Chainable, Module, TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..functional import debiased_step_size
+from ..opt_utils import debiased_step_size


 class Adam(TensorTransform):
@@ -60,7 +60,7 @@ class Adan(TensorTransform):

     Example:
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adan(),
         tz.m.LR(1e-3),
@@ -30,7 +30,7 @@ class AdaptiveHeavyBall(TensorTransform):
     """
     def __init__(self, f_star: float = 0):
         defaults = dict(f_star=f_star)
-        super().__init__(defaults, uses_grad=False, uses_loss=True)
+        super().__init__(defaults, uses_loss=True)

     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -48,7 +48,7 @@ class ESGD(Transform):
     Using ESGD:
     ```python

-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.ESGD(),
         tz.m.LR(0.1)
@@ -59,7 +59,7 @@ class ESGD(Transform):
     ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):

     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
         tz.m.LR(0.1)
@@ -0,0 +1,186 @@
+from collections import deque
+from typing import Literal, Any
+import warnings
+
+import torch
+from ...core import Chainable, TensorTransform
+from ...linalg import torch_linalg, regularize_eigh
+from .lre_optimizers import LREOptimizerBase
+
+def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, eig_tol):
+    """returns U ``(ndim, rank)``, L ``(rank, )``"""
+    if isinstance(history, torch.Tensor):
+        M = history
+    else:
+        M = torch.stack(tuple(history), dim=1)# / len(history)
+
+    MtM = M.T @ M
+    if damping != 0:
+        MtM.add_(torch.eye(MtM.size(0), device=MtM.device, dtype=MtM.dtype).mul_(damping))
+
+    try:
+        L, Q = torch_linalg.eigh(MtM, retry_float64=True)
+
+        # damping is already added to MTM, rdamping is added afterwards
+        L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eig_tol, damping=0, rdamping=0)
+
+        if L is None or Q is None: # this means there are no finite eigenvalues
+            return None, None
+
+        U = (M @ Q) * L.rsqrt()
+
+        # this damping is added after computing U, this is why I didn't use one in linalg.regularize_eig
+        # that's because we damp singular values this way
+        if rdamping != 0:
+            L.add_(rdamping * L[-1]) # L is sorted in ascending order
+
+        return L, U
+
+    except torch.linalg.LinAlgError:
+        return None, None
+
+
+class GGT(TensorTransform):
+    """
+    GGT method from https://arxiv.org/pdf/1806.02958
+
+    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
+    But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't neeed V.
+
+    This is equivalent to full-matrix Adagrad on recent gradients.
+
+    Args:
+        history_size (int, optional): number of past gradients to store. Defaults to 10.
+        beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
+        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
+        eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
+        truncate (int, optional): number of larges eigenvalues to keep. None to disable. Defaults to None.
+        damping (float, optional): damping value. Defaults to 1e-4.
+        rdamping (float, optional): value of damping relative to largest eigenvalue. Defaults to 0.
+        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
+        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
+
+    ## Examples:
+
+    Limited-memory Adagrad
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(),
+        tz.m.LR(0.1)
+    )
+    ```
+    Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(0.01)
+    )
+    ```
+
+    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.ClipNormByEMA(max_ema_growth=1.2),
+        tz.m.LR(0.01)
+    )
+    ```
+    Reference:
+        Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – С. 102-110.
+    """
+
+    def __init__(
+        self,
+        history_size: int = 100,
+        update_freq: int = 1,
+        eig_tol: float = 1e-7,
+        truncate: int | None = None,
+        damping: float = 1e-4,
+        rdamping: float = 0,
+        eigenbasis_optimizer: LREOptimizerBase | None = None,
+        concat_params: bool = True,
+
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults['concat_params']
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
+        update_freq = setting['update_freq']
+
+        if 'history' not in state: state['history'] = deque(maxlen=history_size)
+        history = state['history']
+
+        t = tensor.clone().view(-1)
+        history.append(t)
+
+        step = state.get('step', 0)
+        state['step'] = step + 1
+
+        if step % update_freq == 0 :
+
+            # compute new factors
+            L = state.get("L", None)
+            U = state.get("U", None)
+
+            L_new, U_new = ggt_update(
+                history,
+                damping=setting["damping"],
+                rdamping=setting["rdamping"],
+                truncate=setting["truncate"],
+                eig_tol=setting["eig_tol"],
+            )
+
+            # reproject eigenbasis optimizer
+            eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
+            if eigenbasis_optimizer is not None:
+                if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
+                    eigenbasis_state = state["eigenbasis_state"]
+                    eigenbasis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=eigenbasis_state)
+
+
+            # store new factors
+            if L_new is not None: state["L"] = L_new
+            if U_new is not None: state["U"] = U_new
+
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        g = tensor.view(-1)
+        U = state.get('U', None)
+
+        if U is None:
+            # fallback to element-wise preconditioning
+            history = torch.stack(tuple(state["history"]), 0)
+            g /= history.square().mean(0).sqrt().add(1e-8)
+            return g.view_as(tensor)
+
+        L = state['L']
+
+        # step with eigenbasis optimizer
+        eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
+        if eigenbasis_optimizer is not None:
+
+            if "eigenbasis_state" not in state: state["eigenbasis_state"] = {}
+            eigenbasis_state = state["eigenbasis_state"]
+
+            update = eigenbasis_optimizer.step(g, L=L, Q=U, state=eigenbasis_state)
+            return update.view_as(tensor)
+
+        # or just whiten
+        z = U.T @ g
+        update = (U * L.rsqrt()) @ z
+        return update.view_as(tensor)
+
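Editor's note: the GGT docstring above describes recovering U and S from an eigendecomposition of MᵀM instead of an SVD of M. As a hedged aside (not part of the package), here is a tiny numerical check of that equivalence using only standard torch.linalg calls; variable names are illustrative.

```python
import torch

ndim, rank = 1000, 10
M = torch.randn(ndim, rank, dtype=torch.float64)   # columns = recent gradients
g = torch.randn(ndim, dtype=torch.float64)

# direct route: thin SVD of M
U, S, _ = torch.linalg.svd(M, full_matrices=False)
update_svd = (U / S) @ (U.T @ g)                   # U S^-1 Uᵀ g

# GGT route: eigh of MᵀM gives S^2 (as L) and V; then U = M V S^-1 = (M @ Q) * L.rsqrt()
L, Q = torch.linalg.eigh(M.T @ M)
U2 = (M @ Q) * L.rsqrt()
update_eigh = (U2 * L.rsqrt()) @ (U2.T @ g)        # same U S^-1 Uᵀ g, no SVD of M needed

print(torch.allclose(update_svd, update_eigh, atol=1e-8))  # True up to numerical error
```

The point of the trick is that MᵀM is rank × rank (here 10 × 10), so the eigendecomposition cost does not grow with the parameter dimension.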
@@ -1,10 +1,11 @@
+from typing import Any
 import torch

 from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


-def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
+def lion_(tensors: TensorList | Any, exp_avg_: TensorList | Any, beta1, beta2,):
     update = exp_avg_.lerp(tensors, 1-beta1).sign_()
     exp_avg_.lerp_(tensors, 1-beta2)
     return update
@@ -0,0 +1,299 @@
+"""subspace optimizers to be used in a low rank eigenbasis
+
+three opts support this - GGT and experimental AdaNystrom and Eigengrad
+
+I could define repoject on a module but because most opts use per-parameter state that is complicated"""
+
+import math
+from abc import ABC, abstractmethod
+from typing import Any, cast
+
+import torch
+
+from ...linalg import matrix_power_eigh, torch_linalg
+from .lion import lion_
+
+class LREOptimizerBase(ABC):
+    """Optimizer to run in a low rank eigenbasis.
+
+    notes:
+
+    1. it shouldn't store any states in self, everything should be in state.
+    This is because this may be called on multiple parameters in a sequence
+
+    2. apply is always called first, than reproject whenever eigenbasis gets updated
+
+    3. L is variance in the eigenbasis.
+    """
+    @abstractmethod
+    def step(self, g: torch.Tensor, L: torch.Tensor, Q: torch.Tensor, state: dict) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def reproject(self, L_old: torch.Tensor, Q_old: torch.Tensor,
+                  L_new: torch.Tensor, Q_new: torch.Tensor, state: dict) -> None:
+        ...
+
+class Whiten(LREOptimizerBase):
+    """This simply applies whitening and is equivalent to not running an optimizer in the eigenbasis"""
+    def step(self, g, L, Q, state): return (Q * L.rsqrt()) @ (Q.T @ g)
+    def reproject(self, L_old, Q_old, L_new, Q_new, state): pass
+
+class EMA(LREOptimizerBase):
+    """Maintains exponential moving average of gradients in the low rank eigenbasis. Nesterov setting is experimental"""
+    def __init__(self, beta=0.9, nesterov:bool=False, cautious:bool=False, whiten:bool=True):
+        self.beta = beta
+        self.nesterov = nesterov
+        self.whiten = whiten
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+
+        exp_avg = state["exp_avg"]
+        exp_avg.lerp_(g, 1-self.beta)
+
+        if self.nesterov:
+            dir = (g + exp_avg * self.beta) / (1 + self.beta)
+        else:
+            dir = exp_avg
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        if self.whiten: return (Q * L.rsqrt()) @ dir
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+
+
+def adam(g:torch.Tensor, state:dict, beta1, beta2, eps):
+
+    if "exp_avg" not in state:
+        state["exp_avg"] = torch.zeros_like(g)
+        state["exp_avg_sq"] = torch.zeros_like(g)
+        state["current_step"] = 1
+
+    exp_avg = state["exp_avg"]
+    exp_avg_sq = state["exp_avg_sq"]
+    current_step = state["current_step"]
+
+    exp_avg.lerp_(g, 1-beta1)
+    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1-beta2)
+    denom = exp_avg_sq.sqrt().add_(eps)
+
+    bias_correction1 = 1.0 - (beta1 ** current_step)
+    bias_correction2 = 1.0 - (beta2 ** current_step)
+    alpha = math.sqrt(bias_correction2) / bias_correction1
+    state["current_step"] = current_step + 1
+
+    return (exp_avg * alpha) / denom
+
+def _squared_reproject(C: torch.Tensor, sq: torch.Tensor, exact: bool):
+    if exact:
+        return (C @ sq.diag_embed() @ C.T).diagonal()
+
+    return C.square() @ sq
+
+class Adam(LREOptimizerBase):
+    """Runs Adam in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, cautious:bool=False, eps=1e-8, exact_reproject:bool=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.cautious = cautious
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        dir = adam(g, state, self.beta1, self.beta2, self.eps)
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
+
+
+class FullMatrixAdam(LREOptimizerBase):
+    """Runs full-matrix Adam in low rank eigenbasis.
+    The preconditioner is updated whenever basis is updated"""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, matrix_power=-1/2, abs=True, cautious:bool=False):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.matrix_power = matrix_power
+        self.abs = abs
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        # initialize
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+            state["covariance"] = torch.eye(g.numel(), device=g.device, dtype=g.dtype)
+            state["preconditioner"] = torch.eye(g.numel(), device=g.device, dtype=g.dtype)
+            state["reprojected"] = True
+            state["current_step"] = 1
+
+        exp_avg = state["exp_avg"]
+        covariance = state["covariance"]
+        current_step = state["current_step"]
+
+        # update buffers
+        exp_avg.lerp_(g, 1-self.beta1)
+        covariance.lerp_(g.outer(g), weight=1-self.beta2)
+
+        # correct bias
+        bias_correction1 = 1.0 - (self.beta1 ** current_step)
+        exp_avg = exp_avg / bias_correction1
+
+        # after reprojecting update the preconditioner
+        if state["reprojected"]:
+            state["reprojected"] = False
+
+            bias_correction2 = 1.0 - (self.beta2 ** current_step)
+            covariance = covariance / bias_correction2
+
+            reg = torch.eye(covariance.size(0), device=covariance.device, dtype=covariance.dtype).mul_(self.eps)
+            covariance = covariance + reg
+
+            # compute matrix power
+            try:
+                state["preconditioner"] = matrix_power_eigh(covariance, self.matrix_power, abs=self.abs)
+
+            except torch.linalg.LinAlgError:
+
+                # fallback to diagonal
+                state["preconditioner"] = covariance.diagonal().rsqrt().diag_embed()
+
+        # compute the update
+        state["current_step"] = current_step + 1
+        preconditioner = state["preconditioner"]
+        dir = preconditioner @ exp_avg
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+
+        state["reprojected"] = True
+
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["covariance"] = C @ state["covariance"] @ C.T
+
+class Lion(LREOptimizerBase):
+    """Runs Lion in the low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.99, cautious:bool=False):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+
+        dir = cast(torch.Tensor, lion_(g, state["exp_avg"], beta1=self.beta1, beta2=self.beta2))
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+
+
+class Grams(LREOptimizerBase):
+    """Runs Grams in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, exact_reproject=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+        dir = adam(g, state, self.beta1, self.beta2, self.eps)
+        return Q @ dir.copysign(g)
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
+
+
+class LaProp(LREOptimizerBase):
+    """Runs LaProp in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, cautious:bool=False, exact_reproject=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.cautious = cautious
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+            state["exp_avg_sq"] = torch.zeros_like(g)
+            state["current_step"] = 1
+
+        exp_avg = state["exp_avg"]
+        exp_avg_sq = state["exp_avg_sq"]
+        current_step = state["current_step"]
+
+        # update second moments
+        exp_avg_sq.mul_(self.beta2).addcmul_(g, g, value=1-self.beta2)
+        bias_correction2 = 1.0 - (self.beta2 ** current_step)
+
+        # divide by bias corrected second moments
+        dir = g / (exp_avg_sq / bias_correction2).sqrt().add_(self.eps)
+
+        # update first moments and bias correct
+        exp_avg.lerp_(dir, 1-self.beta1)
+        bias_correction1 = 1.0 - (self.beta1 ** current_step)
+        dir = exp_avg / bias_correction1
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        state["current_step"] = current_step + 1
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
@@ -35,7 +35,7 @@ class MARSCorrection(TensorTransform):

     Mars-AdamW
     ```python
-    optimizer = tz.Modular(
+    optimizer = tz.Optimizer(
         model.parameters(),
         tz.m.MARSCorrection(beta=0.95),
         tz.m.Adam(beta1=0.95, beta2=0.99),
@@ -46,7 +46,7 @@ class MARSCorrection(TensorTransform):

     Mars-Lion
     ```python
-    optimizer = tz.Modular(
+    optimizer = tz.Optimizer(
         model.parameters(),
         tz.m.MARSCorrection(beta=0.9),
         tz.m.Lion(beta1=0.9),
@@ -4,7 +4,7 @@ import torch

 from ...core import Chainable, Transform, HVPMethod
 from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
-from ..functional import initial_step_size
+from ..opt_utils import initial_step_size


 class MatrixMomentum(Transform):
@@ -4,7 +4,7 @@ import torch

 from ...core import Chainable, Module, Transform, TensorTransform, step, Objective
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
-from ..functional import ema_
+from ..opt_utils import ema_
 from ..momentum.momentum import nag_


@@ -99,7 +99,7 @@ class MSAMMomentum(TensorTransform):

     ```python

-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.MSAM(1e-3)
     )
@@ -109,7 +109,7 @@ class MSAMMomentum(TensorTransform):
     To make Adam_MSAM and such, use the ``tz.m.MSAMObjective`` module.

     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
         tz.m.Debias(0.9, 0.999),
@@ -166,7 +166,7 @@ class MSAM(Transform):
     AdamW-MSAM

     ```py
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         bench.parameters(),
         tz.m.MSAMObjective(
             [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],